
Graph-Safe RNG State Exchange for Tensor Parallelism #113541

@eee4017

Description


🚀 The feature, motivation and pitch

Feature Overview

This proposal introduces an improved system for managing Random Number Generator (RNG) states in PyTorch, focusing on Tensor Parallelism (TP) and Pipeline Parallelism (PP) under CUDA Graph. The current approach, which exchanges states via set_rng_state and get_rng_state, is not cudagraph-safe. To address this, we propose an index-based state management system that extends the existing direct state manipulation (get_state/set_state in CUDAGenerator): states are registered with the generator and selected by index, which keeps RNG state operations cudagraph-safe.

Background

PyTorch's current mechanism for handling RNG state under CUDA Graph uses a specialized design (see Note [CUDA Graph-safe RNG states], introduced in #47989 by @mcarilli). It uses the PhiloxCudaState structure to record the intra-graph offset of each kernel, together with a pointer to a one-element, stream-local int64_t device tensor holding the initial offset value.
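For intuition, here is a minimal sketch (assuming a CUDA device is available; it is an illustration, not part of the original design note) of what that design buys: random ops captured in a graph draw fresh values on every replay, because each replay re-fills the stream-local offset tensor rather than reusing the capture-time offset.

import torch

torch.cuda.manual_seed(0)
g = torch.cuda.CUDAGraph()

# The tensor allocated during capture is a static buffer that every replay
# writes into again.
with torch.cuda.graph(g):
    out = torch.rand(4, device="cuda")

g.replay()
first = out.clone()
g.replay()
second = out.clone()

# Each replay reads its initial Philox offset from the one-element device
# tensor described above, which is refreshed before the replay, so the two
# replays are expected to produce different random values.
print(torch.equal(first, second))  # expected: False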

Motivation

In TP and PP scenarios it is common to exchange RNG states, i.e. to set and get states from Python. This pattern originates from MegatronLM and usually takes the form of a context manager such as TensorParallelRNGTracker in torch/distributed/_tensor/random.py, which wraps set_rng_state/get_rng_state. However, this approach breaks down when the state changes inside a CUDA Graph: PhiloxCudaState is designed around a single initial offset value per graph, whereas state operations in TP and PP can require multiple offsets within one graph.

class TensorParallelRNGTracker:
  ...
  def _distribute_region(self, spec: DTensorSpec):
    ...
    self._device_handle.set_rng_state(self.rng_states["tensor-parallel-rng"])
    try:
      yield
    finally:
      self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state()

This problem was also mentioned in #83148. During CUDA graph capture, attempting to set or get the RNG state with set_state can trigger runtime assertions in CUDAGeneratorImpl, as the snippet below demonstrates:

import torch

device = torch._C._cuda_getDevice()
cuda_generator = torch.cuda.default_generators[device]

torch.cuda.manual_seed(0)
orig_state = cuda_generator.get_state()

# Function to execute RNG steps with optional state reset
def step(reset_state=False):
    torch.rand((3, 3), device=device)
    if reset_state:  # Restore the original state if requested
        cuda_generator.set_state(orig_state)
    torch.rand((3, 3), device=device)

print("Normal RNG, final offset should increment twice")
torch.cuda.manual_seed(0)
print("Initial offset =", cuda_generator.get_offset())
step(reset_state=False)
print("Final offset =", cuda_generator.get_offset())

print("\nRNG with state restoration, final offset should increment once")
torch.cuda.manual_seed(0)
orig_state = cuda_generator.get_state()
print("Initial offset =", cuda_generator.get_offset())
step(reset_state=True)
print("Final offset =", cuda_generator.get_offset())

print("\nCaptured with CUDA graph, RNG with state restoration.")
torch.cuda.manual_seed(0)
orig_state = cuda_generator.get_state()
print("Initial offset =", cuda_generator.get_offset())
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    step(reset_state=True)
g.replay()
print("Final offset =", cuda_generator.get_offset())

Proposed Solution

To overcome these challenges, we propose the following enhancements:

  1. Preservation of Original API: Preserve the original set_state and get_state API in CUDAGeneratorImpl, but these will not be cudagraph-safe.
  2. Registration and Indexing of Multiple States: We plan to register multiple states within a CUDAGenerator and also copy them to the GPU. During capture, PhiloxCudaState will record, for each kernel, the pointer to the currently selected initial offset value.
  3. New Cudagraph-Safe C++ APIs of CUDAGeneratorImpl:
    • register_state_with_index: To register a state and return its index.
    • set_state_with_index/get_state_with_index: To switch between registered states using indices.
  4. New Python API:
    • register_rng_state_with_index: For copying and indexing the current RNG state.
    • Modifications to get_rng_state/set_rng_state to accept a use_index flag for index-based control.

This system will enable tools like TensorParallelRNGTracker to use the index-based APIs for cudagraph-safe RNG state manipulation, as demonstrated in the code snippet below.

class TensorParallelRNGTracker:
  def __init__(self, device_type: str = "cuda"):
    ...
    # Register the current RNG state once and store only its index.
    self.rng_states["tensor-parallel-rng"] = self._device_handle.register_rng_state_with_index()

  def _distribute_region(self, spec: DTensorSpec):
    ...
    # Switch to the registered state by index (cudagraph-safe).
    self._device_handle.set_rng_state(self.rng_states["tensor-parallel-rng"], use_index=True)
    try:
      yield
    finally:
      self.rng_states["tensor-parallel-rng"] = self._device_handle.get_rng_state(use_index=True)
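For completeness, here is a hypothetical end-to-end sketch of how the proposed Python API could interact with CUDA graph capture. register_rng_state_with_index and the use_index flag are the additions proposed above and do not exist in current PyTorch, and exposing them under torch.cuda is an assumption of this sketch:

import torch

torch.cuda.manual_seed(0)
# Proposed API (not yet in PyTorch): copy the current state into the
# generator's registry and get back an index referring to it.
tp_index = torch.cuda.register_rng_state_with_index()

def step_with_restore():
    torch.rand((3, 3), device="cuda")
    # Proposed index-based switch: only the offset pointer recorded in
    # PhiloxCudaState changes, so this is intended to be legal during capture.
    torch.cuda.set_rng_state(tp_index, use_index=True)
    torch.rand((3, 3), device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    step_with_restore()
g.replay()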

I've developed a working implementation; I may submit it as a pull request later.

cc @ngimel @ezyang @ptrblck @csarofeen @ajtulloch

Alternatives

No response

Additional context

No response
