Skip to content

Commit 50ac546

Browse files
committed
[c10d] differentiate timeout errors from nccl errors
Summary: It's important for c10d to distinguish between the different causes of watchdog failures (e.g., timeouts vs. NCCL errors) and to let users handle each error according to its type. Test Plan: UT Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent 20af56d commit 50ac546

File tree

4 files changed

+44
-39
lines changed

4 files changed

+44
-39
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2681,7 +2681,7 @@ def assert_fut_success(fut):
26812681
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
26822682
work.wait()
26832683
result = work.get_future_result().wait()
2684-
self.assertEqual(WorkResult(result), WorkResult.FAILURE)
2684+
self.assertEqual(WorkResult(result), WorkResult.COMM_ERROR)
26852685

26862686
if prev_nccl_async_error_handling is not None:
26872687
os.environ[

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,27 +1859,60 @@ void ProcessGroupNCCL::watchdogHandler() {
18591859
// aborted, So cannot check exception based on them. But watchdog needs to
18601860
// finish the check for the works that have already been enqueued to
18611861
// workMetaList_
1862+
1863+
// check NCCL errors first
18621864
if (!terminateProcessGroup_.load()) {
18631865
work.checkAndSetException();
18641866
}
1865-
bool timedOut = work.checkTimeout();
1866-
1867-
// If work hits an exception (either an error or timeout)
18681867
if (work.exception()) {
1868+
// log as soon as exception is detected
1869+
LOG(ERROR) << c10::str(
1870+
logPrefix(),
1871+
"NCCL error is detected by watchdog at work: ",
1872+
work.seq_,
1873+
", last enqueued NCCL work: ",
1874+
pgStatus_->lastEnqueuedSeq,
1875+
", last completed NCCL work: ",
1876+
pgStatus_->lastCompletedSeq,
1877+
".");
18691878
if (work.futureWorkResult_ && !work.futureWorkResult_->completed()) {
18701879
work.futureWorkResult_->markCompleted(
1871-
at::IValue(static_cast<uint8_t>(WorkResult::FAILURE)));
1880+
at::IValue(static_cast<uint8_t>(WorkResult::COMM_ERROR)));
18721881
}
1873-
// log as soon as exception is detected
1882+
} else if (work.checkTimeout()) {
18741883
LOG(ERROR) << c10::str(
18751884
logPrefix(),
1876-
"Exception (either an error or timeout) detected by watchdog at work: ",
1885+
"Work timeout is detected by watchdog at work: ",
18771886
work.seq_,
18781887
", last enqueued NCCL work: ",
18791888
pgStatus_->lastEnqueuedSeq,
18801889
", last completed NCCL work: ",
18811890
pgStatus_->lastCompletedSeq,
18821891
".");
1892+
if (work.futureWorkResult_ && !work.futureWorkResult_->completed()) {
1893+
work.futureWorkResult_->markCompleted(
1894+
at::IValue(static_cast<uint8_t>(WorkResult::TIMEOUT)));
1895+
}
1896+
// Report desync state in case of timeout
1897+
if (desyncDebug_) {
1898+
try {
1899+
collectiveDebugInfoMode_.store(true);
1900+
auto desyncMsg = getNCCLWatchdogDebugInfo();
1901+
LOG(ERROR) << logPrefix() << desyncMsg;
1902+
} catch (const std::exception& e) {
1903+
LOG(ERROR) << logPrefix()
1904+
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
1905+
<< " Please file an issue. Error: " << e.what();
1906+
} catch (...) {
1907+
LOG(ERROR)
1908+
<< logPrefix()
1909+
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
1910+
<< " Please file an issue.";
1911+
}
1912+
}
1913+
}
1914+
// If work hits an exception (either an error or timeout)
1915+
if (work.exception()) {
18831916
// try to notify other ranks via global TCPStore to dump the flight
18841917
// recorder when a collective timeout or exception happens. Flight
18851918
// recorder behavior is independent of desync Debug.
@@ -1919,36 +1952,6 @@ void ProcessGroupNCCL::watchdogHandler() {
19191952
// rank
19201953
abortComms();
19211954
}
1922-
1923-
// Report desync state in case of timeout
1924-
if (timedOut) {
1925-
LOG(ERROR) << c10::str(
1926-
logPrefix(),
1927-
"Timeout at NCCL work: ",
1928-
work.seq_,
1929-
", last enqueued NCCL work: ",
1930-
pgStatus_->lastEnqueuedSeq,
1931-
", last completed NCCL work: ",
1932-
pgStatus_->lastCompletedSeq,
1933-
".");
1934-
if (desyncDebug_) {
1935-
try {
1936-
collectiveDebugInfoMode_.store(true);
1937-
auto desyncMsg = getNCCLWatchdogDebugInfo();
1938-
LOG(ERROR) << logPrefix() << desyncMsg;
1939-
} catch (const std::exception& e) {
1940-
LOG(ERROR)
1941-
<< logPrefix()
1942-
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
1943-
<< " Please file an issue. Error: " << e.what();
1944-
} catch (...) {
1945-
LOG(ERROR)
1946-
<< logPrefix()
1947-
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
1948-
<< " Please file an issue.";
1949-
}
1950-
}
1951-
}
19521955
// Throw exception
19531956
work.handleException(asyncErrorHandling_);
19541957
}

torch/csrc/distributed/c10d/Work.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ enum class OpType : std::uint8_t {
3737
// TODO: support different types of failures/errors
3838
enum class WorkResult : std::uint8_t {
3939
SUCCESS = 0,
40-
FAILURE = 1,
40+
TIMEOUT = 1,
41+
COMM_ERROR = 2,
4142
UNKNOWN = 100,
4243
};
4344

torch/csrc/distributed/c10d/init.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2916,7 +2916,8 @@ Example::
29162916

29172917
py::enum_<::c10d::WorkResult>(module, "WorkResult")
29182918
.value("SUCCESS", ::c10d::WorkResult::SUCCESS)
2919-
.value("FAILURE", ::c10d::WorkResult::FAILURE)
2919+
.value("TIMEOUT", ::c10d::WorkResult::TIMEOUT)
2920+
.value("COMM_ERROR", ::c10d::WorkResult::COMM_ERROR)
29202921
.value("UNKNOWN", ::c10d::WorkResult::UNKNOWN);
29212922

29222923
py::class_<::c10d::WorkInfo, std::shared_ptr<::c10d::WorkInfo>>(

0 commit comments

Comments
 (0)