Skip to content

Conversation

@shuqiangzhang
Copy link
Contributor

@shuqiangzhang shuqiangzhang commented Nov 8, 2024

Stack from ghstack (oldest at bottom):

Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Nov 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140087

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 1a1049d with merge base e474f0d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Nov 8, 2024
shuqiangzhang added a commit that referenced this pull request Nov 8, 2024
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 8025d59
Pull Request resolved: #140087
@shuqiangzhang shuqiangzhang marked this pull request as draft November 8, 2024 02:02
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
@shuqiangzhang shuqiangzhang changed the title [PGNCCL] Add an API to check the healthiness of each PG [PGNCCL] Add an API to get the status/error code of each PG Nov 8, 2024
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Nov 8, 2024
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 548ff56
Pull Request resolved: #140087
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Nov 8, 2024
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: f4ea1c5
Pull Request resolved: #140087
@shuqiangzhang shuqiangzhang requested review from c-p-i-o, d4l3k, fduwjj, kwen2501 and wconstab and removed request for kwen2501 November 8, 2024 22:46
@shuqiangzhang shuqiangzhang marked this pull request as ready for review November 8, 2024 22:47
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Nov 9, 2024
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.
Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7b55b9e
Pull Request resolved: #140087
class ErrorType(Enum):
NO_ERROR = ...
TIMEOUT = ...
NCCL_ERROR = ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generalize NCCL_ERROR to BACKEND_ERROR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making it COMM_ERROR, similar to the workResult errors

Comment on lines 151 to 159
enum TORCH_API ErrorType {
NO_ERROR = 0,
TIMEOUT = 1,
NCCL_ERROR = 2,
// TODO, do we need to distinguish between remote timeout or remote NCCL
// errors?
REMOTE_ERROR = 3,
};

Copy link
Collaborator

@kwen2501 kwen2501 Nov 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can define this API one level up in Backend.hpp?

Comment on lines 771 to 772
ErrorType getError();

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same suggestion for uplifting. For non-NCCL backends, we can leave it as unimplemented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was debating about this with myself too, can make it in any backend

Copy link
Collaborator

@kwen2501 kwen2501 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me.

Comment on lines +1956 to +1959
void ProcessGroupNCCL::broadcastSignal(
c10::intrusive_ptr<Store>& store,
const std::string& signal,
int srcRank) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the signature, it looks like we can make this a util function independent of ProcessGroupNCCL? Not a hard requirement though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will keep it in PGNCCL first, we can move it later if other backends need it

Comment on lines +1986 to +1992
try {
auto vec = store->get(std::string(signal));
TORCH_CHECK_WITH(
DistBackendError,
vec.size() == sizeof(int),
"Invalid size for the timeout rank ID");
std::memcpy(&srcRank, vec.data(), vec.size());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Care to give a comment for this block?

Comment on lines 2089 to 2102
// Check and set remote error if it has not been set before
if (getError() == ErrorType::NO_ERROR) {
int remoteErrorRank =
getSignalSrcRank(store_, std::string(REMOTE_ERROR_SIGNAL));
if (remoteErrorRank != -1) {
std::lock_guard<std::mutex> lock(errorMutex_);
error_ = ErrorType::REMOTE_ERROR;
LOG(ERROR) << c10::str(
logPrefix(),
" remote error detected by watchdog thread from rank: ",
remoteErrorRank);
}
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: modularize this as checkRemoteError?

Comment on lines 2115 to 2119
if (work.exception() && getError() == ErrorType::NO_ERROR) {
// set the error to the first error found
std::lock_guard<std::mutex> lock(errorMutex_);
error_ = ErrorType::NCCL_ERROR;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: lock -> read -> write to guarantee atomicity.
Or, we can use std::atomic.

Comment on lines 128 to 130
constexpr const char* EXCEPTION_DUMP = "exception_dump";

constexpr const char* REMOTE_ERROR_SIGNAL = "remote_error";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: storeDumpKey, storeErrorSignalKey

Comment on lines 2091 to 2092
int remoteErrorRank =
getSignalSrcRank(store_, std::string(REMOTE_ERROR_SIGNAL));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't checked whether getSignalSrcRank could be blocking or not. Let's be careful when putting potentially blocking call in watchdog.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's non blocking in a sense, we first 'check' if key exists which is nonblocking, and then read/get the value only if the key exists. Otherwise, if we try to get the kv directly, it could be blocking.

Comment on lines 2145 to 2147
// broadcast remote error signal to all other ranks in this specific PG.
broadcastSignal(store_, std::string(REMOTE_ERROR_SIGNAL), rank_);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't checked whether broadcastSignal could be blocking or not. Let's be careful when putting potentially blocking call in watchdog.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same functionality was in watchdog thread in the original code

Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
…140087)

Summary:
If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach work specific context and debug the RC of the specific failing work/collective

Note it is critical for all ranks in the PG to be notified about an error as soon as it occurs, so we introduce an errorType of REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a local error) to all other ranks in the PG, the broadcast is done through TCPStore currently

Tags:

Pull Request resolved: pytorch#140087
Approved by: https://github.com/kwen2501
@shuqiangzhang
Copy link
Contributor Author

This PR has caused perf drop in some tests, this is most likely resulted from TCPStore read in watchdog thread. Moving it to another thread, e.g., monitor thread which has been already doing tcpstore reads, could avoid the perf drop. Will do it in another fresh PR, since this PR has been old and conflicts resolving is needed

shuqiangzhang added a commit that referenced this pull request Jan 9, 2025
Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 9, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 9, 2025
Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

ghstack-source-id: fb4ea62
Pull Request resolved: #144498
shuqiangzhang added a commit that referenced this pull request Jan 21, 2025
…r code at the PG level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 21, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 22, 2025
…r code at the PG level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 22, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 22, 2025
Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

ghstack-source-id: aca3df0
Pull Request resolved: #144498
shuqiangzhang added a commit that referenced this pull request Jan 23, 2025
…r code at the PG level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 23, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 23, 2025
…r code at the PG level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 23, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 24, 2025
…r code at the PG level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 24, 2025
… level"

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
shuqiangzhang added a commit that referenced this pull request Jan 24, 2025
Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

ghstack-source-id: 3f945c9
Pull Request resolved: #144498
pytorchmergebot pushed a commit that referenced this pull request Jan 24, 2025
…4498)

Summary:
This PR is basically a replacement of
#140087, which caused some perf
drop due to frequent TCPStore check in watchdog thread. The fix is to move the
tcpstore check in monitoring thread

If unhealthy, the user should be able to get the type of errors, e.g.,
timeout,nccl error or remote error.

This API is applied to PG level, compared to the
work.get_future_result() API which is applied to Work Level.
Error detection at PG level is much more convenient for users to handle
the PG failure as a whole, e.g, restarting the PG.

Error handling at the work level is still useful for users to attach
work specific context and debug the RC of the specific failing
work/collective

Note it is critical for all ranks in the PG to be notified about an
error as soon as it occurs, so we introduce an errorType of
REMOTE_ERROR, which is 'broadcasted' from a src rank (which detects a
local error) to all other ranks in the PG, the broadcast is done through
TCPStore currently

Tags:

Pull Request resolved: #144498
Approved by: https://github.com/kwen2501
@github-actions github-actions bot deleted the gh/shuqiangzhang/62/head branch February 9, 2025 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category Reverted

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants