
Check DataPtr's deleter to determine if it is allocated by CUDA in record_stream #27405

@mrshenli

Description


#21449 skips a pointer if the block is not found, but this also gives up the opportunity to raise an error on missing blocks if the ptr does come from the CUDA caching allocator. @ezyang proposed a way to address this issue by checking the deleter on the pointer.

you determine if a pointer was produced by the CUDA caching allocator by attempting to look it up in the block map; if it's not in the block map, it's not a CUDA caching allocator pointer. My point is that another way to test if it's produced by the CUDA caching allocator is by checking if the deleter for the DataPtr in question is for the CUDA caching allocator (you can perform this test using cast_context).

Update

Since this is turning into a bootcamp issue, let me add more context here.

Background

The API recordStream records that some CUDA stream is using the given CUDA memory block, which prevents the block from being freed too early. If you look at the free() code in CUDACachingAllocator, you will find that instead of freeing the block right away, it inserts an event into the stream and only frees the block after the event has fired:

if (!block->stream_uses.empty()) {
  insert_events(block);
} else {
  free_block(block);
}
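
For illustration, here is a minimal sketch of how a caller would use recordStream so that a block touched on a side stream is not reused too early. The function name and the stream setup are made up for this example; only the recordStream call mirrors the call sites quoted below:

// Sketch for illustration only; see the ProcessGroupNCCL excerpt below for a real call site.
#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void use_on_side_stream(const at::Tensor& t) {
  // Grab a stream other than the one the tensor was allocated on.
  c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side);

  // Work launched here runs on `side`.
  auto doubled = t * 2;

  // Record that `side` is using t's block, so the caching allocator will not
  // hand the block out again until the event it inserts on `side` has fired.
  c10::cuda::CUDACachingAllocator::recordStream(
      t.storage().data(), side);
}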

More specifically, it lazily processes those events the next time someone tries to grab a block from the caching allocator:

void process_events()
{
  // Process outstanding cudaEvents. Events that are completed are removed
  // from the queue, and the 'event_count' for the corresponding allocation
  // is decremented. Stops at the first event which has not been completed.
  // Since events on different devices or streams may occur out of order,
  // the processing of some events may be delayed.
  while (!cuda_events.empty()) {
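
The loop body is elided above. Roughly, it queries the oldest outstanding event, stops at the first one that has not completed, and otherwise decrements the block's event_count, freeing the block once the count reaches zero. A simplified sketch, reusing the names from the snippets above rather than copying the allocator code verbatim:

// Simplified sketch of the loop body, not the actual allocator code.
while (!cuda_events.empty()) {
  cudaEvent_t event = cuda_events.front().first;
  Block* block = cuda_events.front().second;

  // Stop at the first event that has not fired yet; later events are
  // processed on a subsequent call.
  if (cudaEventQuery(event) == cudaErrorNotReady) {
    break;
  }

  // The event has completed: one fewer stream is still using this block.
  if (--block->event_count == 0) {
    free_block(block);
  }
  cuda_events.pop_front();
}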

Currently, the recordStream API is mainly used in ProcessGroupNCCL to prevent input and output blocks from being deleted before the collective communication finishes:

for (size_t j = 0; j < inputTensors[0].size(); ++j) {
  // See [Sync Streams].
  c10::cuda::CUDACachingAllocator::recordStream(
      inputTensors[i][j].storage().data(), ncclStreams[i]);
  inputFlattened[i][j].copy_(inputTensors[i][j], true);

Problem Statement

#20658 and #21449 skip the pointer check in recordStream when the pointer is nullptr or its block is not found, which can happen when the block was created by a different process (hence a different caching allocator instance) or when the ptr comes from an empty tensor.

This fixes errors in some use cases, but it also loses the opportunity to check whether the pointer is valid when it was indeed created by the caching allocator in the current process.

Proposed Solution

@ezyang (please correct me if I am wrong) recommended using cast_context to check the deleter on the pointer, so that we can tell whether the pointer actually comes from the caching allocator:

template <typename T>
T* cast_context(DeleterFnPtr expected_deleter) const {
  return ptr_.cast_context<T>(expected_deleter);
}

This means that all call sites of recordStream() need to acquire the DataPtr from the storage instead of directly getting the void*:

pytorch/c10/core/Storage.h (lines 102 to 108 in 1610ea8):

at::DataPtr& data_ptr() {
  return storage_impl_->data_ptr();
}
const at::DataPtr& data_ptr() const {
  return storage_impl_->data_ptr();
}

Then, recordStream() needs to check whether the DataPtr can be cast using the caching allocator's deleter (CudaCachingDeleter), as sketched below.
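
A minimal sketch of what that check could look like, assuming recordStream is changed to take a DataPtr and that the allocator's deleter function is the CudaCachingDeleter named above; this illustrates the proposal and is not the final implementation:

// Hypothetical sketch of the proposed check; the deleter symbol and the new
// recordStream signature are assumptions taken from this issue, not final API.
void recordStream(const at::DataPtr& ptr, c10::cuda::CUDAStream stream) {
  // Empty tensors carry a null data pointer; there is no block to record.
  if (!ptr.get()) {
    return;
  }
  // cast_context returns the stored context only when the deleter matches;
  // a nullptr result means this DataPtr was not produced by this process's
  // CUDA caching allocator (e.g. it was shared from another process), so we
  // silently skip it.
  if (ptr.cast_context<void>(&CudaCachingDeleter) == nullptr) {
    return;
  }
  // At this point the pointer is known to come from this allocator, so a
  // missing block can be reported as an error instead of being ignored.
  // ... look up the block and record `stream` in block->stream_uses ...
}

Call sites such as the ProcessGroupNCCL loop above would then pass inputTensors[i][j].storage().data_ptr() instead of .storage().data().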

Question:

@ezyang @colesbury What if the data pointer comes from a caching allocator in a different process? It would pass the cast check, but the caching allocator in this process still cannot find it, right? I realized the deleter function pointer will have a different value if it comes from a different process.

cc @ngimel @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528


Labels

module: bootcamp, module: cuda, oncall: distributed, triaged
