Skip to content

Commit 1ad6e7a

Browse files
committed
fix sequence number in execution trace dump
1 parent de4c2a3 commit 1ad6e7a

File tree

5 files changed

+154
-53
lines changed

5 files changed

+154
-53
lines changed

test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
2424
bool simulate_error,
2525
int rank,
2626
c10d::OpType opType,
27-
uint64_t seq)
28-
: WorkNCCL("0", "default_pg", device, rank, opType, seq),
27+
uint64_t seq,
28+
bool isP2P)
29+
: WorkNCCL("0", "default_pg", device, rank, opType, seq, isP2P),
2930
simulateError_(simulate_error) {}
3031

3132
std::exception_ptr checkForNCCLErrors() override {
@@ -65,12 +66,18 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
6566
at::Device& device,
6667
int rank,
6768
c10d::OpType opType,
69+
bool isP2P,
6870
const char* profilingTitle,
6971
const std::vector<at::Tensor>& inputs = {},
7072
const std::vector<at::Tensor>& outputs = {},
7173
bool record = false) override {
7274
return c10::make_intrusive<WorkNCCLSimulateErrors>(
73-
device, simulateError_, rank, opType, seqCollective_);
75+
device,
76+
simulateError_,
77+
rank,
78+
opType,
79+
isP2P ? seqP2P_ : seqCollective_,
80+
isP2P);
7481
}
7582

7683
size_t getNCCLCommCacheSize() {
@@ -96,8 +103,9 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
96103
bool set_timedout_error,
97104
int rank,
98105
c10d::OpType opType,
99-
uint64_t seq)
100-
: WorkNCCL("0", "default_pg", device, rank, opType, seq),
106+
uint64_t seq,
107+
bool isP2P)
108+
: WorkNCCL("0", "default_pg", device, rank, opType, seq, isP2P),
101109
setTimedoutError_(set_timedout_error) {}
102110

103111
private:
@@ -127,12 +135,18 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
127135
at::Device& device,
128136
int rank,
129137
c10d::OpType opType,
138+
bool isP2P,
130139
const char* profilingTitle,
131140
const std::vector<at::Tensor>& inputs = {},
132141
const std::vector<at::Tensor>& outputs = {},
133142
bool record = false) override {
134143
return c10::make_intrusive<WorkNCCLTimedoutErrors>(
135-
device, setTimedoutError_, rank, opType, seqCollective_);
144+
device,
145+
setTimedoutError_,
146+
rank,
147+
opType,
148+
isP2P ? seqP2P_ : seqCollective_,
149+
isP2P);
136150
}
137151

138152
void setTimedoutError() {

torch/csrc/distributed/c10d/ParamCommsUtils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
121121
worldSize); \
122122
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
123123
std::initializer_list<const c10::IValue> paramList = { \
124-
c10::IValue(seq), \
124+
seq, \
125125
pgName, \
126126
rank, \
127127
collName, \
@@ -163,7 +163,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
163163
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
164164
std::initializer_list<const c10::IValue> paramList = { \
165165
c10::IValue(InputTensors), \
166-
c10::IValue(seq), \
166+
seq, \
167167
pgName, \
168168
rank, \
169169
collName, \

0 commit comments

Comments
 (0)