Description
#21449 skips a pointer if the block is not found, but this also gives up the opportunity to raise an error on missing blocks if the ptr does come from the CUDA caching allocator. @ezyang proposed a way to address this issue by checking the deleter on the pointer.
you determine if a pointer was produced by the CUDA caching allocator by attempting to look it up in the block map; if it's not in the block map, it's not a CUDA caching allocator pointer. My point is that another way to test if it's produced by the CUDA caching allocator is by checking if the deleter for the DataPtr in question is for the CUDA caching allocator (you can perform this test using cast_context).
Update
Since this is turning into a bootcamp issue, let me add more context here.
Background
The API recordStream records that some CUDA stream is using the given CUDA memory block, which prevents the block from being freed too early. If you look at the free() code in CUDACachingAllocator, you will find that instead of freeing the block right away, it inserts an event on each recorded stream and only frees the block after those events have fired.
pytorch/c10/cuda/CUDACachingAllocator.cpp, lines 363 to 367 in 1610ea8:

```cpp
if (!block->stream_uses.empty()) {
  insert_events(block);
} else {
  free_block(block);
}
```
More specifically, those events are processed lazily, the next time someone tries to grab a block from the caching allocator:
pytorch/c10/cuda/CUDACachingAllocator.cpp, lines 774 to 781 in 1610ea8:

```cpp
void process_events()
{
  // Process outstanding cudaEvents. Events that are completed are removed
  // from the queue, and the 'event_count' for the corresponding allocation
  // is decremented. Stops at the first event which has not been completed.
  // Since events on different devices or streams may occur out of order,
  // the processing of some events may be delayed.
  while (!cuda_events.empty()) {
```
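Continuing the sketch above (reusing its Block struct and cuda_events queue), the lazy processing boils down to polling events in FIFO order, stopping at the first one that has not completed, and freeing a block once its last event has fired:

```cpp
// Sketch only: cudaEventQuery() returns cudaSuccess once the event has fired
// and cudaErrorNotReady while work on its stream is still pending.
void process_events_sketch(
    std::deque<std::pair<cudaEvent_t, Block*>>& cuda_events) {
  while (!cuda_events.empty()) {
    cudaEvent_t event = cuda_events.front().first;
    Block* block = cuda_events.front().second;
    if (cudaEventQuery(event) == cudaErrorNotReady) {
      break;  // this event has not fired yet; stop processing
    }
    cudaEventDestroy(event);
    if (--block->event_count == 0) {
      // Safe now: return the block to the free pool (free_block(block) in
      // the real allocator).
    }
    cuda_events.pop_front();
  }
}
```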
Currently, the recordStream API is mainly used in ProcessGroupNCCL to prevent input and output blocks from being deleted before collective communications finish:
pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp, lines 740 to 745 in 1610ea8:

```cpp
for (size_t j = 0; j < inputTensors[0].size(); ++j) {
  // See [Sync Streams].
  c10::cuda::CUDACachingAllocator::recordStream(
      inputTensors[i][j].storage().data(), ncclStreams[i]);
  inputFlattened[i][j].copy_(inputTensors[i][j], true);
```
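To make the pattern concrete outside ProcessGroupNCCL, here is a hedged sketch of a call site using the current void*-based signature; the function name and tensors are made up for illustration:

```cpp
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

// `t` is allocated on the current stream but consumed by a copy enqueued on
// `side_stream`; recordStream tells the caching allocator not to reuse t's
// block until the work on `side_stream` has completed. (The same reasoning
// would apply to `dst`.)
void copy_on_side_stream(const at::Tensor& t, at::Tensor& dst) {
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDACachingAllocator::recordStream(
      t.storage().data(), side_stream);
  c10::cuda::CUDAStreamGuard guard(side_stream);
  dst.copy_(t, /*non_blocking=*/true);  // async copy runs on side_stream
}
```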
Problem Statement
#20658 and #21449 skip the pointer check in recordStream when the pointer is nullptr or its block is not found, which can happen when the block was created by a different process (hence a different caching allocator instance) or when the ptr comes from an empty tensor.
This fixes errors in some use cases, but it also loses the opportunity to check whether the pointer is valid when it was indeed created by the caching allocator on the current process.
Proposed Solution
@ezyang (please correct me if I am wrong) recommended using cast_context to check the deleter on the pointer, so that we can tell whether the pointer actually comes from a caching allocator:
Lines 52 to 55 in 1610ea8:

```cpp
template <typename T>
T* cast_context(DeleterFnPtr expected_deleter) const {
  return ptr_.cast_context<T>(expected_deleter);
}
```
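As a self-contained illustration of what cast_context does (toy deleters and CPU memory, unrelated to the CUDA allocator itself): it returns the stored context only when the supplied deleter matches the DataPtr's deleter, and nullptr otherwise.

```cpp
#include <c10/core/Allocator.h>
#include <cstdlib>
#include <iostream>

// Toy deleters used only for this illustration.
static void my_free(void* ptr) { std::free(ptr); }
static void other_free(void* ptr) { std::free(ptr); }

int main() {
  // Build a DataPtr whose context is the pointer itself and whose deleter is
  // my_free -- similar in spirit to what an allocator does.
  void* raw = std::malloc(16);
  c10::DataPtr dp(raw, /*ctx=*/raw, &my_free, c10::Device(c10::DeviceType::CPU));

  // Matching deleter: cast_context returns the context (non-null here).
  std::cout << (dp.cast_context<void>(&my_free) != nullptr) << "\n";    // prints 1
  // Non-matching deleter: cast_context returns nullptr.
  std::cout << (dp.cast_context<void>(&other_free) != nullptr) << "\n"; // prints 0
  return 0;
}
```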
This means that all call sites of recordStream() need to acquire the DataPtr from the storage instead of directly getting the void*:
Lines 102 to 108 in 1610ea8:

```cpp
at::DataPtr& data_ptr() {
  return storage_impl_->data_ptr();
}

const at::DataPtr& data_ptr() const {
  return storage_impl_->data_ptr();
}
```
Then, recordStream() needs to check whether the DataPtr can be cast using CudaCachingDeleter, as sketched below.
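Putting the pieces together, the check inside recordStream could look roughly like the sketch below. The DataPtr-taking overload, the CudaCachingDeleter symbol, and find_allocated_block are placeholders for whatever the real allocator exposes (Block is the stand-in struct from the earlier sketch); none of this is the actual implementation.

```cpp
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>

// Placeholder declarations for this sketch.
void CudaCachingDeleter(void* ptr);      // deleter installed by the caching allocator
Block* find_allocated_block(void* ptr);  // the allocator's block-map lookup

// Hypothetical new overload: take the DataPtr so the deleter can be inspected.
void recordStream(const c10::DataPtr& data_ptr, c10::cuda::CUDAStream stream) {
  if (data_ptr.cast_context<void>(&CudaCachingDeleter) == nullptr) {
    // Not produced by this process's caching allocator (e.g. an empty tensor
    // or memory shared from another process): skip, as #20658/#21449 do today.
    return;
  }
  // Produced by our caching allocator: a missing block is now a real error.
  Block* block = find_allocated_block(data_ptr.get());
  TORCH_CHECK(block != nullptr, "recordStream: pointer not allocated by this allocator");
  block->stream_uses.insert(stream.stream());
}

// Call sites then pass the DataPtr instead of the raw void*, e.g. in
// ProcessGroupNCCL:
//   c10::cuda::CUDACachingAllocator::recordStream(
//       inputTensors[i][j].storage().data_ptr(), ncclStreams[i]);
```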
Question:
@ezyang @colesbury What if the data pointer comes from a caching allocator of a different process? It would pass the cast check, but the caching allocator in this process still cannot find it, right? I realized the function pointer would have a different value if it came from a different process.
cc @ngimel @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528