Description
Motivation
This RFC proposes a design for a series of generic device runtime APIs tailored to stream-based accelerators, to help users simplify the runtime code they write for different devices (including out-of-tree devices, for which a registration mechanism is provided).
Background
Currently, the runtime of stream-based accelerators can be separated into 6 components:
- Device
- Stream
- Event
- Generator
- Guard
- Allocator
PyTorch already provides some generic APIs for Stream, Event, Guard, and Allocator. For Stream and Event, following the generic Stream and Event design, we can write device-agnostic Python code:
```python
event = torch.Event(device_type)
stream = torch.Stream(device_type)
event.record(stream)
stream.synchronize()
```
A simple summary related to device-agnostic code is listed below.
- Stream (provided in C++ and Python)
- Event (provided in C++ and Python)
- Device (provided in C++, but lacks some functionality)
- Guard (provided in C++)
However, PyTorch lacks some generic device-agnostic APIs, mainly in two scenarios (see the sketch after this list for what users must do today):
- getting/setting the status of Device and Stream on the Python side, with functionality similar to torch.xpu.set_device(1) and torch.xpu.current_stream() but accepting the device type as a parameter rather than being XPU-specific;
- Device/Stream guards on the Python side that work for every type of device.
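For concreteness, without such APIs the same logic has to branch per backend today. A minimal, non-exhaustive illustration (torch.xpu is present only in builds with XPU support):

```python
import torch

def set_device_and_sync(device_type: str, index: int) -> None:
    # Without generic runtime APIs, every backend needs its own branch.
    if device_type == "cuda":
        torch.cuda.set_device(index)
        torch.cuda.synchronize()
    elif device_type == "xpu":
        torch.xpu.set_device(index)
        torch.xpu.synchronize()
    else:
        raise NotImplementedError(f"no generic runtime API for {device_type!r}")
```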
Usage
Since there are no generic device-agnostic APIs yet, how can device-agnostic code be written in other PyTorch components? We found two designs in PyTorch that cover this usage.
- For FSDP, it uses getattr and the _register_device_module registration mechanism to handle different devices, like

```python
backend = getattr(torch, device_type)
if backend.is_available():
    backend.set_device(1)
    backend.synchronize()
```

- For Inductor, we propose a device interface registration mechanism, like
```python
device_interface = get_interface_for_device(device_type)
if device_interface.is_available():
    device_interface.set_device(1)
    device_interface.synchronize()
```

These two methods are not yet unified in PyTorch; a sketch of how an out-of-tree backend can hook into both mechanisms is shown below.
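As a rough sketch only: the module paths below (torch._register_device_module and torch._dynamo.device_interface) are private APIs found in recent PyTorch releases and may move; the stub module and FooInterface are illustrative assumptions.

```python
import types
import torch
from torch._dynamo.device_interface import DeviceInterface, register_interface_for_device

# Out-of-tree backends typically ship under the "privateuse1" device type.
foo = types.ModuleType("foo")
foo.is_available = lambda: True          # toy stubs for illustration only
foo.set_device = lambda index: None
foo.synchronize = lambda: None

# 1) FSDP-style: make `getattr(torch, "privateuse1")` resolve to the module.
torch._register_device_module("privateuse1", foo)

# 2) Inductor-style: register a DeviceInterface for the same device type.
class FooInterface(DeviceInterface):
    @staticmethod
    def is_available() -> bool:
        return True

register_interface_for_device("privateuse1", FooInterface)
```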
Design
To simplify and unify the code, this RFC proposes a series of generic device runtime APIs tailored to stream-based accelerators across different devices (including out-of-tree devices).
As described above, PyTorch already provides generic Python code for Stream and Event (torch.Stream and torch.Event, respectively). However, no Python code is provided for Guard, and some APIs covering Device and Stream status are missing.
So, we propose a design to cover these missing device-agnostic runtime APIs, like the code below.
```python
import torch

device_type = tensor.device.type  # maybe cuda, xpu, mps, mtia, privateuse1, ...
assert torch.has_accelerator(device_type), "No available accelerator detected!"
stream = torch.current_stream(device_type)
torch.set_device(0, device_type)
d1 = torch.maybe_exchange_device(1, device_type)
s1 = torch.Stream(device_type)
with torch.DeviceGuard(2, device_type):
    d1 = torch.current_device(device_type)
with torch.StreamGuard(s1):
    s2 = torch.current_stream(device_type)
...
```

The inspiration for this design comes from the design of the generic Stream and Event: PyTorch promotes torch.xxx.Stream & torch.xxx.Event to torch.Stream & torch.Event and makes the latter device-agnostic. Following this design, we list below the device-agnostic Python runtime APIs that are most used, filtered from some popular repos:
| Device-specific runtime APIs torch.xxx.foo | Device-agnostic runtime APIs torch.foo |
| --- | --- |
| torch.xxx.set_device | torch.set_device |
| torch.xxx.current_device | torch.current_device |
| torch.xxx.device_count | torch.device_count |
| torch.xxx.is_available | torch.has_accelerator |
| torch.xxx.exchange_device | torch.exchange_device |
| torch.xxx.maybe_exchange_device | torch.maybe_exchange_device |
| torch.xxx.set_stream | torch.set_stream |
| torch.xxx.current_stream | torch.current_stream |
Our goal is for torch.foo to cover the most common runtime scenarios and usages, with if/else statements and torch.xxx.foo as an additional supplement in other situations.
NB: We will not provide a device-agnostic API for some backend-specific APIs, such as torch.cuda.default_stream, as other backends have no concept of a default stream. Due to the significant differences in device properties across backends, get_device_properties will also not be covered at this stage.
Simple version: for more convenience, the device-agnostic APIs could stop accepting the device type and instead resolve it via getAccelerator. The common code above can then be simplified to this:
```python
import torch

assert torch.has_accelerator(), "No available accelerator detected!"
stream = torch.current_stream()
torch.set_device(0)
d1 = torch.maybe_exchange_device(1)
s1 = torch.Stream()
with torch.DeviceGuard(2):
    d1 = torch.current_device()
with torch.StreamGuard(s1):
    s2 = torch.current_stream()
...
```

Obviously, this can greatly simplify the code and save users the effort of migrating their code to follow this design. The drawback is that it relies on the assumption that there is only one type of accelerator on the machine. This is an open issue, pending feedback from the PyTorch community; a sketch of how the resolution step could work is shown below.
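As a rough Python-side illustration of that resolution step (the helper name _get_accelerator_type and the probing strategy are hypothetical, standing in for the C++ getAccelerator):

```python
import torch

_CANDIDATE_ACCELERATORS = ("cuda", "xpu", "mtia")  # illustrative, not exhaustive

def _get_accelerator_type() -> str:
    # Hypothetical stand-in for getAccelerator: probe which accelerator
    # backend this build supports and has available.
    for name in _CANDIDATE_ACCELERATORS:
        mod = getattr(torch, name, None)
        if mod is not None and hasattr(mod, "is_available") and mod.is_available():
            return name
    raise RuntimeError("No available accelerator detected!")

def current_stream():
    # Forward to the device-specific API, e.g. torch.cuda.current_stream().
    return getattr(torch, _get_accelerator_type()).current_stream()
```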
I personally prefer that these device-agnostic APIs no longer accept the device type as input. The reasons are:
- Currently, it is enough for a PyTorch binary build to support only one accelerator type;
- It is easy to extend these device-agnostic APIs later to handle scenarios with multiple types of accelerators;
- Users can still use the device-specific APIs torch.xxx.foo instead of torch.foo to handle scenarios with multiple types of accelerators;
- The design does NOT break the previous design philosophies, since torch.foo is only used for accelerators (excluding the CPU), unlike torch.empty, which needs an explicit device parameter to tell the user where the empty tensor will be created.
Also, I list the pros and cons of the simple version here to help us make a decision:
Pros:
- torch.foo will take the same input arguments as torch.xxx.foo, bringing a better user experience;
- it is more concise, making it easier for developers to write device-agnostic code.

Cons:
- no obvious drawbacks.
Also, in some situations, users would like to check which type of accelerator they are using. To handle this scenario, we provide an extra API, torch.current_accelerator, which returns the type of the current accelerator as a string. It can help users handle backend-specific situations, such as the default stream:
```python
if torch.has_accelerator():
    if torch.current_accelerator() == "cuda":
        stream = torch.cuda.default_stream()
    else:
        stream = torch.Stream()
    ...
```

Additional context
We will implement this design on top of DeviceGuardImplInterface, which also provides a registration mechanism for out-of-tree devices.
These device-agnostic runtime APIs should accept the same input types (possibly torch.device, str, int, or None) as torch.xxx.foo. The two calls below should be equivalent when XPU is available (a sketch of the input normalization follows the list):
- torch.xpu.set_device(1)
- torch.set_device(1) # based on getAccelerator.
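A hypothetical normalization helper for these input types (the name _normalize_device_index and the defaulting behavior are illustrative assumptions, mirroring what torch.cuda.set_device already accepts):

```python
from typing import Optional, Union
import torch

def _normalize_device_index(device: Union[torch.device, str, int, None],
                            accelerator: str) -> Optional[int]:
    # Hypothetical helper: map the accepted input types onto a plain index
    # for the underlying device-specific runtime.
    if device is None:
        return None  # let the runtime keep the current device
    if isinstance(device, int):
        return device
    dev = torch.device(device)  # handles both "xpu:1" strings and torch.device
    if dev.type != accelerator:
        raise ValueError(f"expected a {accelerator} device, got {dev.type}")
    return dev.index
```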
Open: is it necessary to pass a device type to these proposed APIs as an input argument? A single accelerator sounds sufficient for most people and most situations.
Besides, we will help
- unify the FSDP and Inductor code using these new APIs, and
- investigate how to unify the device-agnostic APIs related to Generator and Allocator (a note on Generator is sketched below).
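As a possible starting point for the Generator part, note that torch.Generator is already constructible per device today; what remains to unify is the surrounding runtime API. A small illustration:

```python
import torch

# torch.Generator already accepts a device argument; the open work is
# unifying the surrounding runtime APIs (seeding, RNG state) across backends.
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = torch.Generator(device=device)
gen.manual_seed(42)
t = torch.randn(2, 2, generator=gen, device=device)
```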
cc @albanD @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @EikanWang @gujinghui