Skip to content

Commit afe3c0b

Browse files
committed
add c10d backend register mechanism to support 3rd party backends.
The original behavior of PyTorch c10d only supports built-in backends, such as nccl/gloo/mpi. This patch extends the c10d capability to support third-party communication libraries which are derived from the ProcessGroup base class. The related RFC is in: #27955. In this way, users just need to manually import the backend and specify the backend name when invoking torch.distributed.init_process_group(). The proposed logic will check whether the backend is registered through torch.distributed.Backend.register_backend(). As for how to develop a new third-party backend through a C++ extension, please refer to test/cpp_extensions/cpp_c10d_extension.cpp
1 parent 79d47c1 commit afe3c0b

File tree

8 files changed

+361
-4
lines changed

8 files changed

+361
-4
lines changed

docs/source/distributed.rst

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Distributed communication package - torch.distributed
1010
Backends
1111
--------
1212

13-
``torch.distributed`` supports three backends, each with
13+
``torch.distributed`` supports three built-in backends, each with
1414
different capabilities. The table below shows which functions are available
1515
for use with CPU / CUDA tensors.
1616
MPI supports CUDA only if the implementation used to build PyTorch supports it.
@@ -395,6 +395,26 @@ of 16
395395
.. autofunction:: all_gather_multigpu
396396

397397

398+
Third-party backends
399+
--------------------
400+
401+
Besides the GLOO/MPI/NCCL backends, PyTorch distributed supports third-party backends
402+
through a run-time register mechanism.
403+
For references on how to develop a third-party backend through C++ Extension,
404+
please refer to `Tutorials - Custom C++ and CUDA Extensions <https://pytorch.org/
405+
tutorials/advanced/cpp_extension.html>`_ and `test/cpp_extensions/cpp_c10d_extension.cpp`.
406+
The capabilities of third-party backends are decided by their own implementations.
407+
408+
The new backend derives from `c10d.ProcessGroup` and registers the backend name and the
409+
instantiating interface through :func:`torch.distributed.Backend.register_backend` when
410+
imported.
411+
412+
When manually importing this backend and invoking :func:`torch.distributed.init_process_group`
413+
with the corresponding backend name, the `torch.distributed` package runs on the new backend.
414+
415+
.. warning::
416+
Support for third-party backends is experimental and subject to change.
417+
398418
Launch utility
399419
--------------
400420

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,7 @@ def print_box(msg):
801801
'include/c10/cuda/impl/*.h',
802802
'include/c10/hip/*.h',
803803
'include/c10/hip/impl/*.h',
804+
'include/c10d/*.hpp',
804805
'include/caffe2/**/*.h',
805806
'include/torch/*.h',
806807
'include/torch/csrc/*.h',
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Minimal third-party c10d backend used to exercise PyTorch's run-time
// backend registration mechanism. No real communication happens here:
// broadcast/allreduce return a work object that is immediately complete,
// and every other collective throws.
#include "cpp_c10d_extension.hpp"

#include <map>

namespace c10d {

ProcessGroupTest::WorkTest::~WorkTest() {}

// The dummy work object reports itself as finished and successful
// unconditionally, since no asynchronous work is ever started.
bool ProcessGroupTest::WorkTest::isCompleted() {
  return true;
}

bool ProcessGroupTest::WorkTest::isSuccess() const {
  return true;
}

bool ProcessGroupTest::WorkTest::wait() {
  // Nothing to wait for; report success immediately.
  return true;
}

ProcessGroupTest::ProcessGroupTest(int rank, int size)
    : ProcessGroup(rank, size) {}

ProcessGroupTest::~ProcessGroupTest() {}

// broadcast/allreduce are the two "supported" collectives: they ignore the
// tensors and options entirely and hand back an already-completed work object.
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  return std::make_shared<ProcessGroupTest::WorkTest>();
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  return std::make_shared<ProcessGroupTest::WorkTest>();
}

// All remaining collectives are intentionally unimplemented and throw so
// that accidental use by a test is loud rather than silently succeeding.
std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support allreduce_coalesced");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support reduce");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support allgather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support allgather_base");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::barrier(
    const BarrierOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support barrier");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support gather");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  throw std::runtime_error("ProcessGroupTest does not support reduce_scatter");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  throw std::runtime_error("ProcessGroupTest does not support send");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  throw std::runtime_error("ProcessGroupTest does not support recv");
}

std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::recvAnysource(
    std::vector<at::Tensor>& tensor,
    int tag) {
  throw std::runtime_error("ProcessGroupTest does not support recvAnysource");
}

// Factory exposed to Python; this is the callable registered with
// torch.distributed.Backend.register_backend (see the header).
// NOTE(review): `store` and `timeout` are accepted for signature
// compatibility with the built-in backends but ignored here.
std::shared_ptr<ProcessGroup> ProcessGroupTest::createProcessGroupTest(
    const std::shared_ptr<::c10d::Store>& store,
    int rank,
    int size,
    const std::chrono::duration<float>& timeout) {
  return std::make_shared<ProcessGroupTest>(rank, size);
}

// Module binding: exposes only the factory; backend registration itself is
// done by the load-time hook declared in the header.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("createProcessGroupTest", &ProcessGroupTest::createProcessGroupTest);
}

} // namespace c10d
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#pragma once

#include <torch/extension.h>

#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <pybind11/chrono.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>

namespace c10d {

//
// ProcessGroupTest implements dummy bindings for c10d.
//
// It is the reference example for a third-party backend: on library load it
// registers itself under the backend name "test" via
// torch.distributed.Backend.register_backend, so Python code can pass
// backend='test' to torch.distributed.init_process_group().
//
class ProcessGroupTest : public ProcessGroup {
 public:
  // Work handle returned by the collectives below. It is always complete
  // and successful because no real communication is performed.
  class WorkTest : public ProcessGroup::Work {
   public:
    WorkTest() {}

    virtual ~WorkTest();
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait() override;

   protected:
    friend class ProcessGroupTest;
  };

  explicit ProcessGroupTest(int rank = -1, int size = -1);
  virtual ~ProcessGroupTest();

  std::shared_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;

  std::shared_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  std::shared_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  std::shared_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  std::shared_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  // NOTE(review): `override` added on send/recv/recvAnysource so the
  // compiler verifies these match the virtual declarations in
  // ProcessGroup, consistent with the collectives above.
  std::shared_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  std::shared_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  std::shared_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag) override;

  // Create a new ProcessGroupTest instance.
  // `store` and `timeout` are accepted for signature compatibility with
  // the built-in backend factories but are ignored by the implementation.
  static std::shared_ptr<ProcessGroup> createProcessGroupTest(
      const std::shared_ptr<::c10d::Store>& store,
      int rank,
      int size,
      const std::chrono::duration<float>& timeout);

  // Runs when the shared library is loaded (GCC/Clang `constructor`
  // attribute). The Python interpreter is already initialized at that
  // point because the extension is imported from Python.
  static void ProcessGroupTestConstructor() __attribute__((constructor)) {
    py::object module = py::module::import("torch.distributed");
    py::object register_backend = module.attr("Backend").attr("register_backend");
    // The first parameter is the backend name used by user in invoking
    // torch.distributed.init_process_group().
    // Note it could be different with module name. For example, the module
    // name is "torch_test" but the backend name is "test".
    // The second parameter is the instantiation function.
    register_backend("test", py::cpp_function(createProcessGroupTest));
  }

};

} // namespace c10d

test/distributed/test_distributed.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import torch.distributed as dist
1818
import torch.nn as nn
1919
import torch.nn.functional as F
20-
from torch.testing._internal.common_utils import TestCase, run_tests
20+
from torch.testing._internal.common_utils import TestCase, run_tests, find_free_port
21+
from torch.distributed.distributed_c10d import _get_default_group
2122
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
2223
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
2324
from torch.testing._internal.common_distributed import simple_sparse_reduce_tests, skip_if_rocm
@@ -31,6 +32,12 @@
3132

3233
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
3334

35+
CPP_EXTENSIONS_WARNING = """
36+
Ninja (https://ninja-build.org) must be available to run C++ extensions tests,
37+
but it could not be found. Install ninja with `pip install ninja`
38+
or `conda install ninja`.
39+
"""
40+
3441
BACKEND = os.environ["BACKEND"]
3542
TEMP_DIR = os.environ["TEMP_DIR"]
3643
INIT_METHOD = os.getenv("INIT_METHOD", "env://")
@@ -150,6 +157,21 @@ def wrapper(*args, **kwargs):
150157
return wrapper
151158

152159

160+
def skip_if_no_ninja(func):
    """Decorator that skips the test when ninja is unavailable.

    Building C++ extensions requires ninja. The previous behavior returned 0
    from the wrapper, which made the test report as *passed* even though it
    never ran; raising unittest.SkipTest makes the missing dependency visible
    in the test report instead of silently masking it.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            import torch.utils.cpp_extension
            # Raises RuntimeError when ninja cannot be found on PATH.
            torch.utils.cpp_extension.verify_ninja_availability()
        except RuntimeError:
            print(CPP_EXTENSIONS_WARNING)
            raise unittest.SkipTest("ninja is not available")

        return func(*args, **kwargs)

    return wrapper
153175
def require_backend(backends):
154176
if BACKEND not in backends:
155177
return unittest.skip("Test requires backend to be one of %s" % backends)
@@ -2181,6 +2203,45 @@ def _join_and_reduce(self, fn):
21812203
class TestMPI(TestCase, _DistTestBase):
21822204
pass
21832205

2206+
elif BACKEND == "test":
2207+
class TestBackendDynamicLoad(TestCase):
    # Exercises run-time registration of a third-party c10d backend:
    # builds test/cpp_extensions/cpp_c10d_extension.cpp as a C++ extension,
    # whose load-time hook registers itself under the backend name "test".
    def setUp(self):
        super(TestBackendDynamicLoad, self).setUp()

    def _load_test_backend(self):
        # Build into a fresh temp dir to avoid reusing a stale build.
        temp_dir = tempfile.mkdtemp()
        src = "{}/../cpp_extensions/cpp_c10d_extension.cpp".format(os.path.abspath(os.path.dirname(__file__)))
        # Loading the extension triggers its load-time hook, which calls
        # torch.distributed.Backend.register_backend("test", ...).
        extension = torch.utils.cpp_extension.load(
            name="torch_test",
            sources=[src],
            build_directory=temp_dir
        )

    @skip_if_no_ninja
    def test_backend_apis(self):
        self._load_test_backend()

        # Single-process group; init_method='env://' reads these variables.
        os.environ['WORLD_SIZE'] = '1'
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(find_free_port())
        os.environ['RANK'] = '0'

        dist.init_process_group(backend='test', init_method='env://', world_size=1, rank=0)
        self.assertEqual(dist.get_rank(), 0)
        self.assertEqual(dist.get_world_size(), 1)

        process_group = _get_default_group()
        # The dummy backend implements only allreduce/broadcast; both return
        # work objects that report immediate completion and success.
        work = process_group.allreduce([torch.rand(1), torch.rand(1)])
        self.assertTrue(work.wait())
        self.assertTrue(work.is_completed())
        self.assertTrue(work.is_success())

        work = process_group.broadcast([torch.rand(1)])
        self.assertTrue(work.wait())
        self.assertTrue(work.is_completed())
        self.assertTrue(work.is_success())

        dist.destroy_process_group()
21842245

21852246
if __name__ == "__main__":
21862247
assert (

test/run_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@
148148

149149

150150
if dist.is_available():
151+
DISTRIBUTED_TESTS_CONFIG['test'] = {
152+
'WORLD_SIZE': '1'
153+
}
151154
if not TEST_WITH_ROCM and dist.is_mpi_available():
152155
DISTRIBUTED_TESTS_CONFIG['mpi'] = {
153156
'WORLD_SIZE': '3',

test/test_determination.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def test_torch_file(self):
9292
self.assertEqual(
9393
self.determined_tests(["torch/utils/cpp_extension.py"]),
9494
[
95+
"distributed/test_distributed",
9596
"test_cpp_extensions_aot_ninja",
9697
"test_cpp_extensions_aot_no_ninja",
9798
"test_determination",

0 commit comments

Comments
 (0)