
[RFC] Intel GPU Distributed Support in PyTorch #135791

@zhangxiaoli73


🚀 The feature, motivation and pitch

Motivation

This Request for Comments (RFC) proposes and opens for discussion Intel GPU distributed support in PyTorch, introducing to the PyTorch community an Intel GPU distributed backend built on top of the Intel® oneAPI Collective Communications Library (oneCCL).

This initiative begins with the integration of oneCCL into PyTorch as a new distributed backend. It marks a significant stride toward making Intel GPUs a robust compute backend for PyTorch distributed scenarios. This RFC outlines a high-level design for the integration.

Design

1. Distributed Backend on Intel GPU

In the current design, PyTorch distributed uses the c10d::ProcessGroup class as an interface that manages multiple communication backends (each inheriting from c10d::Backend) and provides collective APIs dispatched by device type and backend.
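To make the dispatch concrete: torch.distributed.init_process_group already accepts a per-device backend mapping, so one ProcessGroup can route CPU tensors to Gloo and CUDA tensors to NCCL. A minimal sketch follows; the xpu:xccl entry is the proposed addition, not an existing option, and RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are assumed to be set in the environment:

```python
import torch.distributed as dist

# One ProcessGroup, two backends: collectives on CPU tensors are routed
# to gloo, collectives on CUDA tensors to nccl.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")

# Under this proposal, XPU tensors would dispatch the same way:
# dist.init_process_group(backend="cpu:gloo,xpu:xccl")
```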

Regarding per-backend implementations, c10d::ProcessGroupNCCL targets CUDA devices under the backend name "nccl". Similarly, we would like to add c10d::ProcessGroupXCCL for Intel GPU devices under the new backend name "xccl". NOTE: the device name for Intel GPU in PyTorch is XPU; therefore, xccl stands for XPU Collective Communications Library in this post.

We can visualize this design as below:
[Figure: ProcessGroup dispatching to per-device backends, with the proposed ProcessGroupXCCL alongside ProcessGroupNCCL]

On the frontend, we follow the existing backend-registration methodology in PyTorch: a backend is registered with ProcessGroup by device_type, backend_type, and backend_class. Each backend must define its own collective implementations, such as allreduce, reduce_scatter, allgather, etc.
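For reference, PyTorch already exposes this registration path to Python via torch.distributed.Backend.register_backend. Below is a hedged sketch of how the proposed xccl backend could plug in; _create_xccl_backend is a placeholder name, and the real implementation would return the C++ ProcessGroupXCCL object:

```python
import torch.distributed as dist

def _create_xccl_backend(store, rank, world_size, timeout):
    # Placeholder creator: the actual backend would construct and return
    # the C++ c10d::ProcessGroupXCCL object bound through pybind11.
    raise NotImplementedError

# Register "xccl" for the "xpu" device type so that
# dist.init_process_group("xccl") resolves to this backend.
dist.Backend.register_backend("xccl", _create_xccl_backend, devices=["xpu"])
```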

The scope of the XCCL backend is as follows (a usage sketch follows the list):

  • Features:
    • Collectives: broadcast, allreduce, allreduce_coalesced, reduce, allgather, _allgather_base, allgather_coalesced, allgather_into_tensor_coalesced, gather, scatter, reduce_scatter, _reduce_scatter_base, reduce_scatter_tensor_coalesced, alltoall_base, alltoall, send, recv, barrier.
    • Data types: FP32, BF16, FP16, FP64, INT64, INT32 and INT8
    • Reduction types: SUM, MIN, MAX, PRODUCT
  • Intel® Data Center GPU Max Series, scale-up on multiple devices
  • Linux only
  • Pip only with pre-built packages @ https://download.pytorch.org/
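Putting the scope together, a single-node multi-XPU run could look like the sketch below. This assumes the backend lands as proposed; the "xccl" backend name comes from this RFC and is not yet in released PyTorch:

```python
import os
import torch
import torch.distributed as dist

def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Bind each rank to one Intel GPU (PyTorch device type "xpu").
    torch.xpu.set_device(rank % torch.xpu.device_count())
    dist.init_process_group("xccl", rank=rank, world_size=world_size)

    # allreduce with a supported dtype (BF16) and reduction type (SUM).
    t = torch.ones(1024, dtype=torch.bfloat16, device="xpu")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert int(t[0].item()) == world_size

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```

Launched, for example, with `torchrun --nproc-per-node=<num_xpus> this_script.py`, which sets RANK and WORLD_SIZE for each process.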

2. Intel Distributed Backend built with oneCCL

In addition to the SYCL runtime library, the XCCL backend depends on the Intel oneCCL runtime library. From an engineering perspective, the XCCL backend will be built into libtorch_xpu.so, dynamically linked against the oneCCL runtime library libccl.so. Additionally, the PyTorch C++ frontend code that needs Python bindings will be added to libtorch_python.so.
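A quick runtime sanity check under these assumptions: torch.xpu.is_available() verifies the SYCL runtime and driver are visible, while an availability query for the new backend, sketched here as dist.is_xccl_available() purely by analogy with dist.is_nccl_available(), would report whether this build of libtorch_xpu.so carries the dynamically linked XCCL backend:

```python
import torch
import torch.distributed as dist

# True when the SYCL runtime/driver is visible and at least one
# XPU device is present.
print(torch.xpu.is_available())

# Hypothetical, by analogy with dist.is_nccl_available(): would report
# whether this PyTorch build includes the XCCL backend.
# print(dist.is_xccl_available())
```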

We can visualize this call stack as below:

[Figure: XCCL backend call stack across libtorch_python.so, libtorch_xpu.so, libccl.so, and the SYCL runtime]

PR Plan

The code changes touch several parts of PyTorch. To keep the review clear and manageable, we will split the changes into three PRs, ordered by priority.

Alternatives

No response

Additional context

No response

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @gujinghui @EikanWang @fengyuan14 @guangyey
