
Conversation

@Chao1Han
Contributor

@Chao1Han Chao1Han commented Sep 20, 2024

Motivation:

Corresponding to the [RFC] Intel GPU Distributed Support in PyTorch, this is the first PR to enable distributed support on Intel GPUs (the device name for Intel GPUs in PyTorch is XPU). In this PR, we would like to add a new distributed backend, ProcessGroupXCCL, and implement allreduce as an entry point.

Solution:

  1. Add a new distributed backend ProcessGroupXCCL, named xccl, short for XPU Collective Communications Library. ProcessGroupXCCL inherits from the c10d::Backend class, like ProcessGroupNCCL.
  2. Implement allreduce for ProcessGroupXCCL in this PR. More collectives for ProcessGroupXCCL will be added in follow-up PRs.
  3. Build ProcessGroupXCCL into libtorch_xpu.so under the build flag USE_XCCL. This flag is ON only when the Intel SYCL runtime and the oneCCL runtime library are installed and both USE_DISTRIBUTED and USE_XPU are ON.

Example:

Here is a simple example of using spawn to launch the XCCL backend and perform allreduce on XPU tensors.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def setup(rank, world_size):
    # Rendezvous settings for the default process group.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # Explicitly request the XCCL backend.
    dist.init_process_group(backend='xccl', rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def run_allreduce(rank, world_size):
    setup(rank, world_size)
    # One XPU device per rank.
    device = torch.device('xpu:{}'.format(rank))
    x = torch.randn([2, 2], device=device)
    # Sum-reduce x across all ranks, in place.
    dist.all_reduce(x)
    cleanup()


if __name__ == '__main__':
    world_size = 2
    mp.spawn(run_allreduce, args=(world_size,), nprocs=world_size, join=True)

UT Plan:

Add collective unit test cases in test/distributed/test_c10d_xccl.py to verify that ProcessGroupXCCL is registered correctly and that the allreduce operation produces correct results.
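For illustration, here is a minimal sketch of the kind of correctness check such a test performs. The actual test file uses PyTorch's multi-process test utilities; the helper name below is hypothetical.

import torch
import torch.distributed as dist

def check_allreduce(rank, world_size):
    # Every rank contributes a tensor of ones; after all_reduce(SUM),
    # every element should equal world_size on every rank.
    t = torch.ones(2, 2, device=f"xpu:{rank}")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert torch.equal(t.cpu(), torch.full((2, 2), float(world_size)))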

Additional Context:

cc @jgong5 @gujinghui @EikanWang @fengyuan14 @guangyey

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Sep 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136343

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 385c218 with merge base fe0e9fb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla

linux-foundation-easycla bot commented Sep 20, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Sep 20, 2024
@Chao1Han Chao1Han marked this pull request as draft September 20, 2024 03:08
@Chao1Han Chao1Han changed the title Xccl process group for Pytorch Add a new distributed backend (XCCL) for Intel GPUs Sep 20, 2024
Contributor

@malfet malfet left a comment

It looks like this PR introduces a lot of code duplication by renaming NCCL to XCCL.
If collective communication principles on XPUs are very different from GPUs, then perhaps it indeed requires a different backend; but if they are the same, perhaps the code could be refactored to share more of the infrastructure with NCCL.

Also, I would like to challenge the user story here a bit.
It would be great if users did not have to modify their programs for every new backend, but rather could use with torch.device('xpu'): and keep the rest of the code unchanged.

@kwen2501
Collaborator

I asked a question in the original RFC about the detailed product line supported. Would appreciate comments :)

@EikanWang EikanWang self-requested a review September 23, 2024 09:04
@zhangxiaoli73
Contributor

zhangxiaoli73 commented Sep 26, 2024

Hi @malfet, thank you for your comments. We are currently working on refactoring the code so that XCCL can reuse more of the NCCL backend. Meanwhile, we will still need to maintain some XCCL-specific differences.

For backend selection, PyTorch has implemented a feature that allows automatic backend selection: NCCL on CUDA devices and XCCL on XPU devices, respectively. Therefore, when users do not specify a backend explicitly, this automatic selection takes effect without any program modification. I believe that is what you are asking for; is my understanding correct?

Let me show you a common example:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch._utils import _get_device_module
import os

def init_process_group(rank, world_size):
    dist.init_process_group(rank=rank, world_size=world_size)

Users can initialize the global default process group in torch.distributed using the rank ID and world size, without specifying a backend name. During this initialization, the available backends in the default_device_backend_map will be registered to the process group.

default_device_backend_map: Dict[str, str] = {
    "cpu": GLOO,
    "cuda": NCCL,
    "xpu": XCCL
}

When users call distributed collectives such as dist.all_reduce(tensor) in this demo, the process group dispatches the collective operation to a backend based on the device type of the input tensor.

def worker(rank, device_type):
    init_process_group(rank=rank, world_size=2)
    device_module = _get_device_module(device_type)
    device_module.set_device(rank)
    # All ranks must pass tensors of the same shape to all_reduce.
    tensor = torch.randn(10).to(device_type)
    dist.all_reduce(tensor)
    device_module.synchronize(device=rank)


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    device_type = "xpu"  # device_type = "cuda" also works
    mp.spawn(worker, nprocs=2, args=(device_type,))

For high-level distributed components like DDP and FSDP, the collective calls are encapsulated within these components, but the underlying methodology remains the same.
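As an illustration only, here is a minimal device-agnostic DDP sketch on top of that dispatch; it assumes the default process group has already been initialized as in the worker above, and is not code from this PR:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run_ddp_step(rank, device_type="xpu"):
    device = torch.device(device_type, rank)
    model = nn.Linear(10, 10).to(device)
    # DDP infers the device from the module's parameters; its internal
    # gradient all_reduce calls are dispatched to XCCL (or NCCL) based on
    # the device type of the gradients.
    ddp_model = DDP(model)
    out = ddp_model(torch.randn(4, 10, device=device))
    out.sum().backward()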

@zhangxiaoli73
Contributor

Hi @kwen2501, I responded to the product line question in the RFC and would like to quote it here: "For the product line, we are developing on Intel® Data Center GPU Max Series based on Intel® Tiber™ Developer Cloud and will continue to enable the next generation of Data Center GPUs afterwards." Let me know if you have other questions. Thanks.

@kwen2501
Collaborator

That's clear, thanks @zhangxiaoli73

Collaborator

@kwen2501 kwen2501 left a comment

Can you add a short description to the build logic in this PR?
Such as the "if sth exists / sth enabled; else ..." decisions.
It can help us keep a record.
(Would be even better if you could include it in some of the cmake files -- to make them more readable!)
Thanks!

Comment on lines 9 to 12
set(XCCL_ROOT "")
if(DEFINED ENV{CCL_ROOT})
set(XCCL_ROOT $ENV{CCL_ROOT})
endif()
Collaborator

What is the relationship between XCCL_ROOT and CCL_ROOT?
It looks like CCL_ROOT is the user-facing env?

Contributor Author

Yes, CCL_ROOT is a user-facing environment variable that is automatically set by Intel oneCCL.

target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
set_source_files_properties(
${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp
PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL")
Collaborator

For my education, what do "CCL_ENABLE_ZE" and "CCL_ENABLE_SYCL" stand for?

Contributor Author

@Chao1Han Chao1Han Sep 29, 2024

Those flags are needed by oneCCL (a communication library from Intel) for the gcc build, to let oneCCL know that we are building host code with gcc in the framework and need the device kernels at runtime.

Collaborator

@Chao1Han , please add some comments to elaborate on the motivation.

Contributor Author

Sure.

Collaborator

let oneCCL know that we are building host code with gcc in the framework and need the device kernels at runtime.

@Chao1Han CCL is a library, while CCL_ENABLE_ZE;CCL_ENABLE_SYCL are compilation flags. How does the Intel CCL library become aware of these compilation flags?

Contributor

Those flags are needed by oneCCL to find the correct headers. We will define those flags in the XCCL backend file instead of passing them to the gcc compiler.

@kwen2501
Collaborator

kwen2501 commented Sep 28, 2024

Two more general questions:

  • How are users going to use the XCCL library?
    Via system installation and/or the dynamic link loader (e.g. LD_LIBRARY_PATH), right?

  • What components of ProcessGroupNCCL do you wish to see "device neutralized"? We are happy to convert possible ones into general utils.
    cc: @fduwjj @c-p-i-o


if(USE_XPU)
if(USE_XCCL)
append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS)
Collaborator

Will libtorch_xpu_distributed_extra_sources be used in other places? It seems like libtorch_xpu_distributed_extra_sources is an unused variable right now.

Contributor Author

The purpose of the append_filelist function is to append the files listed in the first variable to the second variable. libtorch_xpu_distributed_extra_sources is defined in build_variables.bzl. The reason ProcessGroupXCCL.cpp is added to this list instead of being directly appended to the Caffe2_XPU_SRCS variable is that we plan to add more utility files to the list later.
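For the record, here is a hedged sketch of what that entry in build_variables.bzl could look like; this is illustrative only and the actual list contents in the PR may differ:

# In build_variables.bzl (Starlark, Python-like syntax):
libtorch_xpu_distributed_extra_sources = [
    "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp",
    # more XCCL utility sources are planned to be appended here later
]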


Comment on lines +183 to +184
if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \
(torch.xpu.is_available() and torch.xpu.device_count() >= x):
Collaborator

Should we provide comprehensive support in the testing utilities? Why does at_least_x_gpu not need to support Intel GPUs? Because we do not test it?

Contributor Author

Yes, we will provide sufficient tests. This modification is limited to this particular section because the test cases for this PR currently only involve this part. Changes to unit tests will be coordinated with the frontend PRs (e.g., DDP, FSDP, etc.).
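For illustration, an XPU-aware version of such a helper could look like the following hedged sketch, mirroring the quoted check; the real helper lives in torch/testing/_internal and may differ:

import torch

def at_least_x_gpu(x):
    # True if either the CUDA or the XPU runtime exposes at least x devices.
    if torch.cuda.is_available() and torch.cuda.device_count() >= x:
        return True
    if torch.xpu.is_available() and torch.xpu.device_count() >= x:
        return True
    return False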

Comment on lines +198 to +205
try:
    from torch._C._distributed_c10d import ProcessGroupXCCL

    ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupXCCL"]
except ImportError:
    _XCCL_AVAILABLE = False

Collaborator

Is it possible to refine the logic for the different backends a little bit? It seems like we just copy-pasted the code.

Contributor Author

Yes, this follows the logic used for UCC, NCCL, and Gloo. We can consider refining this logic later.
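As one possible direction for that later refinement (a hedged sketch, not part of this PR), the per-backend try/except blocks could be collapsed into a single loop; the backend class names below are the ones exposed by torch._C._distributed_c10d:

import importlib

_OPTIONAL_BACKENDS = ["ProcessGroupGloo", "ProcessGroupNCCL", "ProcessGroupUCC", "ProcessGroupXCCL"]
_c10d = importlib.import_module("torch._C._distributed_c10d")

exported = []
for _name in _OPTIONAL_BACKENDS:
    _cls = getattr(_c10d, _name, None)
    if _cls is not None:
        # Mirror what each copy-pasted block does today.
        _cls.__module__ = "torch.distributed.distributed_c10d"
        exported.append(_name)  # the real module would do: __all__ += [_name]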

@zhangxiaoli73
Contributor

Can you add a short description to the build logic in this PR? Such as the "if sth exists / sth enabled; else ..." decisions. It can help us keep a record. (Would be even better if you could include it in some of the cmake files -- to make them more readable!) Thanks!

Hi @kwen2501, sure. The build logic for the XCCL backend is:

In the distributed backend build (USE_DISTRIBUTED=ON):
  if the SYCL runtime and the oneCCL runtime are both installed on the system,
  then
      build flags USE_XPU=ON, USE_XCCL=ON and USE_C10D_XCCL=ON;
      the XCCL backend will be built into libtorch_xpu;
  else
      USE_XCCL=OFF and USE_C10D_XCCL=OFF;

Note: if you want to disable the XCCL backend build, please manually set `USE_XCCL=OFF`.
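A quick runtime way to tell whether a given build produced the XCCL backend, as a hedged sketch that mirrors the try/except registration pattern used in this PR:

def xccl_built():
    try:
        from torch._C._distributed_c10d import ProcessGroupXCCL  # noqa: F401
        return True
    except ImportError:
        return False

print("XCCL available:", xccl_built())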

@gujinghui gujinghui marked this pull request as ready for review October 9, 2024 07:03
@gujinghui
Collaborator

Hi @malfet

It looks like this PR introduces a lot of code duplication by renaming NCCL to XCCL. If collective communication principles on XPUs are very different from GPUs, then perhaps it indeed requires a different backend; but if they are the same, perhaps the code could be refactored to share more of the infrastructure with NCCL.

Makes sense. We will follow up with an abstraction design to make the code more device-agnostic and review it with you.
However, according to our study it will take a long time to complete the code refactoring, considering the changes will impact the NCCL path as well.
Is it possible that we land this PR first, to provide XCCL functionality without any changes to the NCCL path?
Then we will provide a plan to continue co-working on the CCL abstraction.

Also, I would like to challenge the user story here a bit. It would be great if users did not have to modify their programs for every new backend, but rather could use with torch.device('xpu'): and keep the rest of the code unchanged.

According to @zhangxiaoli73's comments, this feature is already implemented in PyTorch.

@gujinghui gujinghui requested a review from malfet October 9, 2024 07:52
@gujinghui
Collaborator


Hi @malfet @kwen2501,

Any comments on the option above? Thanks.

@soulitzer soulitzer added the triaged label Oct 10, 2024
@kwen2501
Collaborator

Any comments on the option above? Thanks.

Yeah, fine with me.

Comment on lines 158 to 160
setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none");
setXCCLEnvVar("CCL_LOCAL_RANK", local_rank);
setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size);
Collaborator

nit: this kind of passing may not work well when there are n-dimensional groups, each with a different local rank and local world size.

Would the oneCCL library consider accepting those as API arguments?

Contributor

Yes, I understand your concern. We have asked the oneCCL library to detect that information automatically instead of it being passed manually, so this code has been removed.

Comment on lines 1655 to 1684
-    if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
+    if device_id is not None and (
+        device_id.index is None
+        or (device_id.type != "cuda" and device_id.type != "xpu")
+    ):
         raise ValueError(
             "init_process_group device_id parameter must be a cuda device with an "
-            "id, e.g. cuda:0, not just cuda or cpu"
+            "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu"
Collaborator

nit: maybe we can just simplify the logic:

if device_id is not None and device_id.index is None:
    raise ValueError(
            "init_process_group device_id parameter must be a device with an index"
    )

Contributor

Thanks, changes done.

Comment on lines 466 to 473
-    nGPUs = torch.cuda.device_count()
+    nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
Collaborator

This change introduces a priority between devices.
Can we determine which device count to use based on the backend: str that is passed in?
The parent function's API:

def init_multigpu_helper(world_size: int, backend: str):

Contributor

Agreed. We will rework those if/else branches in this PR to make the code generic.
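For example, one hedged way to express that; the helper name below is illustrative, not the final code:

import torch

def _device_count_for_backend(backend):
    # Pick the device runtime from the backend string instead of
    # preferring whichever device happens to be available.
    if "xccl" in backend:
        return torch.xpu.device_count()
    return torch.cuda.device_count()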

Contributor

@malfet malfet left a comment

Can we split this PR into 3:

  • First one just introduces some common refactors
  • 2nd one adds XCCL bindings
  • Last one adds testing (which requires some new HW to be added to the infra, doesn't it?)

Comment on lines +648 to +651
if cmake_cache_vars["USE_XCCL"]:
    report("-- Building XCCL library")
else:
    report("-- Not using XCCL")
Contributor

This list keeps growing and I'm not sure anyone is reading it. Can you elaborate on why this is needed, other than to mimic NCCL-like behavior?

Contributor Author

This list appears in the final part of a successful build's output, though it may repeat some of the flags printed at the beginning of the build. I think its purpose is to provide the flag information in the log after an incremental build, since small code changes do not trigger the messages printed from CMakeLists.txt.

Comment on lines 69 to 73
device_count = (
    torch.xpu.device_count()
    if torch.xpu.is_available()
    else torch.cuda.device_count()
)
Contributor

It would be good to move this into torch/testing/_internal as, say, get_device_count() or something like that, because you are making a similar change in 3-4 places in the code.

Contributor

Makes sense. We will do some code generalization for the tests as well.
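A hedged sketch of the suggested shared helper; the name get_device_count() comes from the suggestion above, and the final location and signature may differ:

import torch

def get_device_count():
    # Prefer the accelerator that is actually available in this build/run.
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    return 0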

UCC = 3,
MPI = 4,
CUSTOM = 5,
XCCL = 6,
Contributor

Should we keep CUSTOM as the last option?

Contributor

Right. Changes done.

Comment on lines 511 to 514
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
Contributor

Should we invert the conditional? Or perhaps add inline bool backendSupportsSequenceNumbers() and call it here and on line 533 as well?

Contributor

Yes, it makes sense to put the condition check into a function instead of a long if/else.

Comment on lines 296 to 310
int getXCCLEnvVar(std::string envVarName) {
  char* stringValue = std::getenv(envVarName.c_str());
  if (stringValue != nullptr) {
    try {
      int val = std::stoi(stringValue);
      return val;
    } catch (std::exception& e) {
      TORCH_CHECK(
          false,
          "Invalid value for environment variable: " + std::string(envVarName));
    }
  } else {
    return -1;
  }
}
Contributor

This function does not look specific to XCCL, does it? Can it go into some sort of a generic header?

Contributor

Yes, this function is not specific to XCCL, but it is no longer needed: due to a code logic change in XCCL, there is no need to parse environment variables anymore.

@zhangxiaoli73
Contributor

Can we split this PR into 3:

  • First one just introduces some common refactors
  • 2nd one adds XCCL bindings
  • Last one adds testing (which requires some new HW to be added to the infra, doesn't it?)

@malfet Makes sense. This PR is too large to review, so we plan to split it into several smaller ones, starting from the first point: common code refactoring and generalization.

We're a bit unsure whether to continue the refactoring in this PR or to create a new one. Your comments are welcome.

@Chao1Han Chao1Han requested review from kwen2501 and malfet October 21, 2024 02:53
@wconstab
Contributor

We're a bit unsure whether to continue the refactoring in this PR or to create a new one. Your comments are welcome.

Perhaps it would be better to start by writing an RFC issue or a Google doc with the proposed refactoring steps; we can chat on Slack and discuss until it looks good, then start to implement them? cc @kwen2501

@EikanWang
Collaborator

Perhaps it would be better to start by writing an RFC issue or a Google doc with the proposed refactoring steps; we can chat on Slack and discuss until it looks good, then start to implement them? cc @kwen2501

@wconstab, thanks. It is helpful to ensure we are on the same page first. @Chao1Han, @zhangxiaoli73, let's do that first.

@kwen2501
Collaborator

kwen2501 commented Nov 1, 2024

Just thinking out loud here: as discussed in our meeting, PyTorch 2.5 can auto-load/import a backend's module if it is detected in the installed Python package environment. This would eliminate the need for users to do something like import torch-intel in their script.

Following that capability, it seems the ProcessGroupXCCL implementation in this PR could live in a torch-intel module, just like the implementations of other non-distributed ops. When PyTorch auto-loads the torch-intel module, ProcessGroupXCCL would get registered with torch.distributed's backend dispatch. To that end, we can collaborate on making sure the auto-registration works smoothly when we get there.

@Chao1Han , what do you think?
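A minimal sketch of what such an auto-loaded registration might look like, assuming a hypothetical torch-intel package and an illustrative creator function (the actual ProcessGroupXCCL constructor arguments may differ):

# Inside the hypothetical torch-intel package, executed on auto-import.
import torch.distributed as dist

def _create_xccl_backend(store, rank, world_size, timeout):
    from torch._C._distributed_c10d import ProcessGroupXCCL
    # Illustrative constructor call; the real signature may differ.
    return ProcessGroupXCCL(store, rank, world_size)

dist.Backend.register_backend("xccl", _create_xccl_backend, devices=["xpu"])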

@zhangxiaoli73
Contributor

Thanks, @kwen2501! You provided a great alternative with autoload. For now, we will focus on backend generalization first and then come back to address how to handle XCCL.

pytorchmergebot pushed a commit that referenced this pull request Nov 19, 2024
Citing @malfet's [comment](#136343 (review)) in #136343
> It would be great, if users do not have to modify their programs for every new backend, but rather use with torch.device('xpu'): and keep rest of the code unchanged.

This PR makes the backend specification ("nccl", "gloo") optional when the user provides a `device_id` to `init_process_group` (acceptance of `device_id` has previously been supported for the purpose of eager init).

New user experience:
```
device = torch.device(device_type, rank % device_count)
dist.init_process_group(device_id=device)
```

The line `device = torch.device(...)` is needed anyway, because the user would use it for tensor creation etc.

Pull Request resolved: #140963
Approved by: https://github.com/wconstab
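Putting that together, a hedged end-to-end sketch of the new experience on XPU; rank and device_count are assumed to come from the launcher, and this is illustrative, not code from the referenced PR:

import torch
import torch.distributed as dist

def main(rank, device_count):
    device = torch.device("xpu", rank % device_count)
    # The backend ("xccl" here) is inferred from the device type of device_id;
    # rank/world_size come from the usual env:// rendezvous set up by the launcher.
    dist.init_process_group(device_id=device)
    t = torch.ones(4, device=device)
    dist.all_reduce(t)
    dist.destroy_process_group()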
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
@Chao1Han
Contributor Author

Closing this since #141856 has been merged.

@Chao1Han Chao1Han closed this Dec 16, 2024