
Commit cb354f8

kwen2501 authored and pytorchmergebot committed
1 parent 06075d3 commit cb354f8

File tree

2 files changed: +258 -212 lines changed


torch/csrc/distributed/c10d/NCCLUtils.cpp

Lines changed: 242 additions & 0 deletions
@@ -8,6 +8,87 @@
 
 namespace c10d {
 
+NCCLComm::NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {}
+
+NCCLComm::~NCCLComm() noexcept {
+  // (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver
+  // shutdown error if CUDA context has exited first. Thus, we are not
+  // destroying or aborting NCCL communicators here. We just detect and warn
+  // about the risk of memory leak. Normally, a user would have called
+  // `destroy_process_group` or `abort_process_group`, and such risk would be
+  // avoided.
+  LockType lock(mutex_);
+  if (ncclComm_ && initialized_ && !aborted_) {
+    TORCH_WARN_ONCE(
+        "WARNING: NCCL communicator hasn't been destroyed. This may cause "
+        "memory leaks. To avoid the risk, you can call `destroy_process_group` "
+        "during normal exit or `_abort_process_group` when handling failures.")
+  }
+}
+
+// NOLINTNEXTLINE(*-noexcept-move-*)
+NCCLComm::NCCLComm(NCCLComm&& other) {
+  // Using other's lock, as it reads other's states
+  // Can not use this.mutex_, as this object is being constructed.
+  LockType lock(other.mutex_);
+  std::swap(ncclComm_, other.ncclComm_);
+  std::swap(aborted_, other.aborted_);
+  std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
+  std::swap(initialized_, other.initialized_);
+  std::swap(nonBlocking_, other.nonBlocking_);
+  std::swap(deviceIndex_, other.deviceIndex_);
+}
+
+ncclUniqueId NCCLComm::getNcclId() {
+  return ncclId_;
+}
+
+std::shared_ptr<NCCLComm> NCCLComm::create(
+    int numRanks,
+    int rank,
+    ncclUniqueId commId,
+    at::DeviceIndex deviceIndex) {
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
+  auto comm = std::make_shared<NCCLComm>();
+  C10D_NCCL_CHECK(
+      ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
+      std::nullopt);
+  comm->ncclId_ = commId;
+  comm->rank_ = rank;
+  comm->deviceIndex_ = deviceIndex;
+  comm->initialized_ = true;
+  // Old style comm is always blocking.
+  comm->nonBlocking_ = false;
+  return comm;
+}
+
+#ifdef NCCL_HAS_COMM_NONBLOCKING
+std::shared_ptr<NCCLComm> NCCLComm::create(
+    int numRanks,
+    int rank,
+    ncclUniqueId commId,
+    at::DeviceIndex deviceIndex,
+    ncclConfig_t& config) {
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
+  auto comm = std::make_shared<NCCLComm>();
+  comm->nonBlocking_ = config.blocking == 0;
+  LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
+            << (comm->nonBlocking_ ? "nonblocking" : "blocking");
+  C10D_NCCL_CHECK_NONBLOCKING(
+      ncclCommInitRankConfig(
+          &(comm->ncclComm_), numRanks, commId, rank, &config),
+      std::nullopt);
+  comm->ncclId_ = commId;
+  comm->rank_ = rank;
+  comm->deviceIndex_ = deviceIndex;
+  // Under blocking mode, comm is initialized immediately after NCCL init
+  // returns; Under nonblocking mode, we check whether comm is initialized the
+  // *next* time ncclComm_ is accessed.
+  comm->initialized_ = !comm->nonBlocking_;
+  return comm;
+}
+#endif
+
 ncclComm_t NCCLComm::getNcclComm() {
   LockType lock(mutex_);
   if (aborted_) {
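For orientation, a minimal, hypothetical sketch of how a caller might choose between the two create() overloads added in this hunk. The helper name and the ncclUniqueId exchange (generated on rank 0 via ncclGetUniqueId and shared with all ranks out of band) are assumptions, not part of this commit.

#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

// Hypothetical helper (not in this commit) showing the blocking vs.
// nonblocking init paths. `commId` is assumed to have been generated on
// rank 0 and broadcast to every rank before this call.
std::shared_ptr<c10d::NCCLComm> makeComm(
    int numRanks,
    int rank,
    ncclUniqueId commId,
    at::DeviceIndex deviceIndex,
    bool useNonBlocking) {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  if (useNonBlocking) {
    // Nonblocking init: ncclCommInitRankConfig may return before the comm is
    // ready, so initialized_ stays false until ncclComm_ is next accessed.
    ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
    config.blocking = 0;
    return c10d::NCCLComm::create(numRanks, rank, commId, deviceIndex, config);
  }
#endif
  // Old-style blocking init: the comm is fully initialized on return.
  return c10d::NCCLComm::create(numRanks, rank, commId, deviceIndex);
}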
@@ -56,6 +137,11 @@ void NCCLComm::waitReady(bool longInterval) {
   }
 }
 
+std::optional<std::string> NCCLComm::getNcclCommFailureReason() const {
+  LockType lock(mutex_);
+  return commFailureReason_;
+}
+
 // TODO: why do we have `!defined(FBCODE_CAFFE2)` here?
 #if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
 // last argument to split() API is not used to support
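The stored failure reason pairs with the error-checking and abort helpers added further down in this commit. As a rough, hypothetical illustration (the polling function, its name, and the logging are not part of this change):

#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

// Hypothetical watchdog-style check, assuming `comm` is a live NCCLComm.
void pollForAsyncError(const std::shared_ptr<c10d::NCCLComm>& comm) {
  ncclResult_t err = comm->checkForNcclError();
  if (err != ncclSuccess) {
    // Surface whatever reason was recorded (e.g. a work timeout set by
    // ProcessGroupNCCL) alongside the NCCL error string, then abort the comm.
    auto reason = comm->getNcclCommFailureReason();
    LOG(ERROR) << "NCCL async error on comm " << comm->repr() << ": "
               << ncclGetErrorString(err) << ", reason: "
               << reason.value_or("none recorded");
    comm->abort(reason);
  }
}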
@@ -147,6 +233,162 @@ void NCCLComm::destroy() {
   aborted_ = true;
 }
 
+void NCCLComm::abort(std::optional<std::string> commFailureReason) {
+  LockType lock(mutex_);
+  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
+#ifdef ENABLE_NCCL_ERROR_CHECKING
+  if (aborted_ && !initialized_) {
+    // Should not abort twice.
+    return;
+  }
+
+#ifdef NCCL_HAS_COMM_REGISTER
+  // Deregister all registered segments before aborting.
+  for (auto& it : registeredSegmentHandles_) {
+    void* handle = it.second;
+    C10D_NCCL_CHECK(
+        ::ncclCommDeregister(ncclComm_, handle),
+        c10::str(
+            "Failed to deregister segment handle ",
+            handle,
+            " on ncclComm_ ",
+            ncclComm_));
+  }
+  registeredSegmentHandles_.clear();
+#endif
+
+  // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
+  // timeout)
+  commFailureReason_ = commFailureReason;
+  LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
+            << (commFailureReason ? *commFailureReason
+                                  : "No abort reason provided.");
+#ifndef NCCL_HAS_COMM_NONBLOCKING
+  C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
+#else
+  C10D_NCCL_CHECK_TIMEOUT(
+      ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
+#endif
+  aborted_ = true;
+  ncclComm_ = nullptr;
+
+  // Set an appropriate error so that we avoid using the communicator.
+  if (ncclAsyncErr_ == ncclSuccess) {
+    ncclAsyncErr_ = ncclSystemError;
+  }
+#else
+  // This is a NOOP, if error checks are disabled.
+  return;
+#endif
+}
+
+bool NCCLComm::isInitialized() const {
+  LockType lock(mutex_);
+  return initialized_;
+}
+
+bool NCCLComm::isAborted() const {
+  LockType lock(mutex_);
+  return aborted_;
+}
+
+uint64_t NCCLComm::getCommSplitCounter() const {
+  return ncclCommSplitCounter_;
+}
+
+ncclResult_t NCCLComm::checkForNcclError() {
+  LockType lock(mutex_);
+#ifdef ENABLE_NCCL_ERROR_CHECKING
+  if (ncclAsyncErr_ != ncclSuccess) {
+    return ncclAsyncErr_;
+  }
+  C10D_NCCL_CHECK(
+      ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
+  return ncclAsyncErr_;
+#else
+  // Always return success, if error checks are disabled.
+  return ncclSuccess;
+#endif
+}
+
+ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
+  LockType lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+  // We register only segments from cache allocator
+  // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
+  // maps to a unique handle and should not be registered before the current
+  // ptr is deregistered and freed.
+  TORCH_CHECK(
+      registeredSegmentHandles_.count(ptr) == 0,
+      "Segment with ptr ",
+      ptr,
+      " has already been registered on ncclComm_ ",
+      ncclComm_);
+
+  void* handle = nullptr;
+  // Use getNcclComm to make sure comm is ready before calling nccl APIs
+  auto comm = getNcclComm();
+  C10D_NCCL_CHECK(
+      ncclCommRegister(comm, ptr, size, &handle),
+      c10::str(
+          "Failed to register segment with ptr ",
+          ptr,
+          ", size ",
+          size,
+          " on ncclComm_ ",
+          comm));
+  registeredSegmentHandles_[ptr] = handle;
+  return ncclSuccess;
+#else
+  return ncclInvalidUsage;
+#endif
+}
+
+ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
+  LockType lock(mutex_);
+#ifdef NCCL_HAS_COMM_REGISTER
+  TORCH_CHECK(
+      registeredSegmentHandles_.count(ptr) == 1,
+      "Segment with ptr ",
+      ptr,
+      " is not registered on ncclComm_ ",
+      ncclComm_);
+
+  void* handle = registeredSegmentHandles_[ptr];
+  // Use getNcclComm to make sure comm is ready before calling nccl APIs
+  auto comm = getNcclComm();
+  C10D_NCCL_CHECK(
+      ncclCommDeregister(comm, handle),
+      c10::str(
+          "Failed to deregister segment handle ",
+          handle,
+          ", with ptr ",
+          ptr,
+          " on ncclComm_ ",
+          comm));
+  registeredSegmentHandles_.erase(ptr);
+  return ncclSuccess;
+#else
+  return ncclInvalidUsage;
+#endif
+}
+
+std::string NCCLComm::repr() const {
+  return c10::str((void*)ncclComm_);
+}
+
+#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
+std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
+  std::unordered_map<std::string, std::string> dump;
+  if (isAborted()) {
+    LOG(INFO) << "Communicator was aborted before trying to dump its state.";
+    return dump;
+  }
+  C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
+  return dump;
+}
+#endif
+
 std::string getNcclVersion() {
   static c10::once_flag ncclGetVersionFlag;
   static std::string versionString;
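For context on the registerSegment()/deregisterSegment() pair above, here is a simplified, hypothetical usage sketch. The caller, its name, and the buffer provenance are assumptions; in PyTorch proper, registration is driven by segments from the CUDA caching allocator, as the code comment notes.

#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

// Hypothetical caller: registers a caching-allocator segment with the comm,
// runs collectives, then deregisters before the segment is freed.
void withRegisteredSegment(c10d::NCCLComm& comm, void* segPtr, size_t segSize) {
  // ncclCommRegister is called under the hood; the returned handle is kept
  // inside the comm, keyed by segPtr. Returns ncclInvalidUsage when built
  // without NCCL_HAS_COMM_REGISTER.
  TORCH_CHECK(comm.registerSegment(segPtr, segSize) == ncclSuccess);

  // ... collectives that read/write the registered buffer run here ...

  // Deregister while the comm is still healthy; abort() also deregisters any
  // handles that are still outstanding.
  TORCH_CHECK(comm.deregisterSegment(segPtr) == ncclSuccess);
}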
