@@ -24,8 +24,9 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
2424 bool simulate_error,
2525 int rank,
2626 c10d::OpType opType,
27- uint64_t seq)
28- : WorkNCCL(" 0" , " default_pg" , device, rank, opType, seq),
27+ uint64_t seq,
28+ bool isP2P)
29+ : WorkNCCL(" 0" , " default_pg" , device, rank, opType, seq, isP2P),
2930 simulateError_ (simulate_error) {}
3031
3132 std::exception_ptr checkForNCCLErrors () override {
@@ -65,12 +66,18 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
6566 at::Device& device,
6667 int rank,
6768 c10d::OpType opType,
69+ bool isP2P,
6870 const char * profilingTitle,
6971 const std::vector<at::Tensor>& inputs = {},
7072 const std::vector<at::Tensor>& outputs = {},
7173 bool record = false ) override {
7274 return c10::make_intrusive<WorkNCCLSimulateErrors>(
73- device, simulateError_, rank, opType, seqCollective_);
75+ device,
76+ simulateError_,
77+ rank,
78+ opType,
79+ isP2P ? seqP2P_ : seqCollective_,
80+ isP2P);
7481 }
7582
7683 size_t getNCCLCommCacheSize () {
@@ -96,8 +103,9 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL {
96103 bool set_timedout_error,
97104 int rank,
98105 c10d::OpType opType,
99- uint64_t seq)
100- : WorkNCCL(" 0" , " default_pg" , device, rank, opType, seq),
106+ uint64_t seq,
107+ bool isP2P)
108+ : WorkNCCL(" 0" , " default_pg" , device, rank, opType, seq, isP2P),
101109 setTimedoutError_ (set_timedout_error) {}
102110
103111 private:
@@ -127,12 +135,18 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
127135 at::Device& device,
128136 int rank,
129137 c10d::OpType opType,
138+ bool isP2P,
130139 const char * profilingTitle,
131140 const std::vector<at::Tensor>& inputs = {},
132141 const std::vector<at::Tensor>& outputs = {},
133142 bool record = false ) override {
134143 return c10::make_intrusive<WorkNCCLTimedoutErrors>(
135- device, setTimedoutError_, rank, opType, seqCollective_);
144+ device,
145+ setTimedoutError_,
146+ rank,
147+ opType,
148+ isP2P ? seqP2P_ : seqCollective_,
149+ isP2P);
136150 }
137151
138152 void setTimedoutError () {
0 commit comments