|
8 | 8 |
|
9 | 9 | namespace c10d { |
10 | 10 |
|
| 11 | +NCCLComm::NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {} |
| 12 | + |
| 13 | +NCCLComm::~NCCLComm() noexcept { |
| 14 | + // (kwen2501) Making CUDA/NCCL calls in this destructor can hit CUDA driver |
| 15 | + // shutdown error if CUDA context has exited first. Thus, we are not |
| 16 | + // destroying or aborting NCCL communicators here. We just detect and warn |
| 17 | + // about the risk of memory leak. Normally, a user would have called |
| 18 | + // `destroy_process_group` or `abort_process_group`, and such risk would be |
| 19 | + // avoided. |
| 20 | + LockType lock(mutex_); |
| 21 | + if (ncclComm_ && initialized_ && !aborted_) { |
| 22 | + TORCH_WARN_ONCE( |
| 23 | + "WARNING: NCCL communicator hasn't been destroyed. This may cause " |
| 24 | + "memory leaks. To avoid the risk, you can call `destroy_process_group` " |
| 25 | + "during normal exit or `_abort_process_group` when handling failures.") |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +// NOLINTNEXTLINE(*-noexcept-move-*) |
| 30 | +NCCLComm::NCCLComm(NCCLComm&& other) { |
| 31 | + // Using other's lock, as it reads other's states |
| 32 | + // Can not use this.mutex_, as this object is being constructed. |
| 33 | + LockType lock(other.mutex_); |
| 34 | + std::swap(ncclComm_, other.ncclComm_); |
| 35 | + std::swap(aborted_, other.aborted_); |
| 36 | + std::swap(ncclAsyncErr_, other.ncclAsyncErr_); |
| 37 | + std::swap(initialized_, other.initialized_); |
| 38 | + std::swap(nonBlocking_, other.nonBlocking_); |
| 39 | + std::swap(deviceIndex_, other.deviceIndex_); |
| 40 | +} |
| 41 | + |
| 42 | +ncclUniqueId NCCLComm::getNcclId() { |
| 43 | + return ncclId_; |
| 44 | +} |
| 45 | + |
| 46 | +std::shared_ptr<NCCLComm> NCCLComm::create( |
| 47 | + int numRanks, |
| 48 | + int rank, |
| 49 | + ncclUniqueId commId, |
| 50 | + at::DeviceIndex deviceIndex) { |
| 51 | + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex); |
| 52 | + auto comm = std::make_shared<NCCLComm>(); |
| 53 | + C10D_NCCL_CHECK( |
| 54 | + ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), |
| 55 | + std::nullopt); |
| 56 | + comm->ncclId_ = commId; |
| 57 | + comm->rank_ = rank; |
| 58 | + comm->deviceIndex_ = deviceIndex; |
| 59 | + comm->initialized_ = true; |
| 60 | + // Old style comm is always blocking. |
| 61 | + comm->nonBlocking_ = false; |
| 62 | + return comm; |
| 63 | +} |
| 64 | + |
#ifdef NCCL_HAS_COMM_NONBLOCKING
// Config-aware creation path: honors `config.blocking` to create either a
// blocking or a nonblocking communicator. Only compiled when the NCCL
// version supports nonblocking communicators.
std::shared_ptr<NCCLComm> NCCLComm::create(
    int numRanks,
    int rank,
    ncclUniqueId commId,
    at::DeviceIndex deviceIndex,
    ncclConfig_t& config) {
  // NCCL initialization must run with the target device active.
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex);
  auto comm = std::make_shared<NCCLComm>();
  // config.blocking == 0 means the user asked for a nonblocking comm.
  comm->nonBlocking_ = config.blocking == 0;
  LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: "
            << (comm->nonBlocking_ ? "nonblocking" : "blocking");
  // NONBLOCKING variant of the check: tolerates ncclInProgress results.
  C10D_NCCL_CHECK_NONBLOCKING(
      ncclCommInitRankConfig(
          &(comm->ncclComm_), numRanks, commId, rank, &config),
      std::nullopt);
  comm->ncclId_ = commId;
  comm->rank_ = rank;
  comm->deviceIndex_ = deviceIndex;
  // Under blocking mode, comm is initialized immediately after NCCL init
  // returns; Under nonblocking mode, we check whether comm is initialized the
  // *next* time ncclComm_ is accessed.
  comm->initialized_ = !comm->nonBlocking_;
  return comm;
}
#endif
| 91 | + |
11 | 92 | ncclComm_t NCCLComm::getNcclComm() { |
12 | 93 | LockType lock(mutex_); |
13 | 94 | if (aborted_) { |
@@ -56,6 +137,11 @@ void NCCLComm::waitReady(bool longInterval) { |
56 | 137 | } |
57 | 138 | } |
58 | 139 |
|
| 140 | +std::optional<std::string> NCCLComm::getNcclCommFailureReason() const { |
| 141 | + LockType lock(mutex_); |
| 142 | + return commFailureReason_; |
| 143 | +} |
| 144 | + |
59 | 145 | // TODO: why do we have `!defined(FBCODE_CAFFE2)` here? |
60 | 146 | #if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) |
61 | 147 | // last argument to split() API is not used to support |
@@ -147,6 +233,162 @@ void NCCLComm::destroy() { |
147 | 233 | aborted_ = true; |
148 | 234 | } |
149 | 235 |
|
// Aborts the underlying NCCL communicator, optionally recording why.
// Holds mutex_ for the duration. On success, ncclComm_ is nulled out and
// ncclAsyncErr_ is poisoned so later callers treat the comm as unusable.
// Compiles to a no-op when NCCL error checking is disabled.
void NCCLComm::abort(std::optional<std::string> commFailureReason) {
  LockType lock(mutex_);
  // Run the abort against this communicator's own device.
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
  if (aborted_ && !initialized_) {
    // Should not abort twice.
    return;
  }

#ifdef NCCL_HAS_COMM_REGISTER
  // Deregister all registered segments before aborting.
  for (auto& it : registeredSegmentHandles_) {
    void* handle = it.second;
    C10D_NCCL_CHECK(
        ::ncclCommDeregister(ncclComm_, handle),
        c10::str(
            "Failed to deregister segment handle ",
            handle,
            " on ncclComm_ ",
            ncclComm_));
  }
  registeredSegmentHandles_.clear();
#endif

  // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
  // timeout)
  commFailureReason_ = commFailureReason;
  LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
            << (commFailureReason ? *commFailureReason
                                  : "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
  C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
  // Nonblocking builds: abort may complete asynchronously, so use the
  // timeout-aware check that polls the communicator.
  C10D_NCCL_CHECK_TIMEOUT(
      ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
#endif
  aborted_ = true;
  ncclComm_ = nullptr;

  // Set an appropriate error so that we avoid using the communicator.
  if (ncclAsyncErr_ == ncclSuccess) {
    ncclAsyncErr_ = ncclSystemError;
  }
#else
  // This is a NOOP, if error checks are disabled.
  return;
#endif
}
| 284 | + |
| 285 | +bool NCCLComm::isInitialized() const { |
| 286 | + LockType lock(mutex_); |
| 287 | + return initialized_; |
| 288 | +} |
| 289 | + |
| 290 | +bool NCCLComm::isAborted() const { |
| 291 | + LockType lock(mutex_); |
| 292 | + return aborted_; |
| 293 | +} |
| 294 | + |
| 295 | +uint64_t NCCLComm::getCommSplitCounter() const { |
| 296 | + return ncclCommSplitCounter_; |
| 297 | +} |
| 298 | + |
// Returns the cached asynchronous error if one was already observed;
// otherwise queries NCCL for a new async error and caches the result in
// ncclAsyncErr_. Always reports success when error checking is compiled
// out. Thread-safe via mutex_.
ncclResult_t NCCLComm::checkForNcclError() {
  LockType lock(mutex_);
#ifdef ENABLE_NCCL_ERROR_CHECKING
  // A previously recorded error is sticky: report it without re-querying.
  if (ncclAsyncErr_ != ncclSuccess) {
    return ncclAsyncErr_;
  }
  C10D_NCCL_CHECK(
      ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
  return ncclAsyncErr_;
#else
  // Always return success, if error checks are disabled.
  return ncclSuccess;
#endif
}
| 313 | + |
// Registers a cache-allocator segment [ptr, ptr+size) with this
// communicator (NCCL user-buffer registration) and remembers the returned
// handle for later deregistration. Returns ncclInvalidUsage when built
// without NCCL_HAS_COMM_REGISTER.
// NOTE(review): this holds mutex_ and then calls getNcclComm(), which also
// locks mutex_ — safe only if LockType/mutex_ are recursive; confirm
// against the header.
ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
  LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
  // We register only segments from cache allocator
  // which are guaranteed to be with disjoint addr ranges. Thus, a ptr always
  // maps to a unique handle and should not be registered before the current
  // ptr is deregistered and freed.
  TORCH_CHECK(
      registeredSegmentHandles_.count(ptr) == 0,
      "Segment with ptr ",
      ptr,
      " has already been registered on ncclComm_ ",
      ncclComm_);

  void* handle = nullptr;
  // Use getNcclComm to make sure comm is ready before calling nccl APIs
  auto comm = getNcclComm();
  C10D_NCCL_CHECK(
      ncclCommRegister(comm, ptr, size, &handle),
      c10::str(
          "Failed to register segment with ptr ",
          ptr,
          ", size ",
          size,
          " on ncclComm_ ",
          comm));
  registeredSegmentHandles_[ptr] = handle;
  return ncclSuccess;
#else
  return ncclInvalidUsage;
#endif
}
| 346 | + |
// Deregisters a previously registered segment by its base pointer and
// removes the bookkeeping entry. Returns ncclInvalidUsage when built
// without NCCL_HAS_COMM_REGISTER.
// NOTE(review): holds mutex_ and then calls getNcclComm(), which also locks
// mutex_ — safe only if LockType/mutex_ are recursive; confirm against the
// header.
ncclResult_t NCCLComm::deregisterSegment(void* ptr) {
  LockType lock(mutex_);
#ifdef NCCL_HAS_COMM_REGISTER
  TORCH_CHECK(
      registeredSegmentHandles_.count(ptr) == 1,
      "Segment with ptr ",
      ptr,
      " is not registered on ncclComm_ ",
      ncclComm_);

  void* handle = registeredSegmentHandles_[ptr];
  // Use getNcclComm to make sure comm is ready before calling nccl APIs
  auto comm = getNcclComm();
  C10D_NCCL_CHECK(
      ncclCommDeregister(comm, handle),
      c10::str(
          "Failed to deregister segment handle ",
          handle,
          ", with ptr ",
          ptr,
          " on ncclComm_ ",
          comm));
  registeredSegmentHandles_.erase(ptr);
  return ncclSuccess;
#else
  return ncclInvalidUsage;
#endif
}
| 375 | + |
| 376 | +std::string NCCLComm::repr() const { |
| 377 | + return c10::str((void*)ncclComm_); |
| 378 | +} |
| 379 | + |
#if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP)
// Dumps internal communicator state as key/value strings (NCCLX-only
// debugging facility). Returns an empty map if the communicator was
// already aborted, since dumping it would not be meaningful.
std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
  std::unordered_map<std::string, std::string> dump;
  if (isAborted()) {
    LOG(INFO) << "Communicator was aborted before trying to dump its state.";
    return dump;
  }
  C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), std::nullopt);
  return dump;
}
#endif
| 391 | + |
150 | 392 | std::string getNcclVersion() { |
151 | 393 | static c10::once_flag ncclGetVersionFlag; |
152 | 394 | static std::string versionString; |
|
0 commit comments