Compiling flex attention with batch dependent block mask and dynamic shapes #136196

@SamGalanakis

🐛 Describe the bug

Compiling flex attention with dynamic shapes enabled doesn't work when the BlockMask depends on the batch dimension. My use case is document packing, so both the batch size and the per-batch mask change between steps. This is a follow-up to #134560.
In the code below, block_mask_type = "causal" works fine, "batch_flip_causal" recompiles constantly, and "dense" errors out before the first iteration. This was run on pytorch/pytorch-nightly:2.6.0.dev20240913-cuda12.4-cudnn9-devel.

Also, any tips for building the BlockMask from the dense mask more efficiently would be useful.
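
For context, document packing usually produces a mask that depends on `b` through a per-token document id. The sketch below is illustrative only and is not part of the repro: `document_ids`, `document_causal`, and `dense_from_mask_mod` are hypothetical names, and the mask is evaluated densely on CPU with plain broadcasting rather than through `create_block_mask`.

```python
import torch

# Hypothetical packed batch: each entry is the id of the document
# that the token at that position belongs to.
document_ids = torch.tensor(
    [
        [0, 0, 0, 1, 1, 2],
        [0, 1, 1, 1, 1, 1],
    ]
)


def document_causal(b, h, q_idx, kv_idx):
    # Attend only within the same document, and only to earlier positions.
    same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
    return same_doc & (q_idx >= kv_idx)


def dense_from_mask_mod(mask_mod, B, S):
    # Evaluate mask_mod on every (b, q_idx, kv_idx) triple via broadcasting.
    # The head index is fixed to 0 because this mask ignores it.
    b = torch.arange(B).view(B, 1, 1)
    q_idx = torch.arange(S).view(1, S, 1)
    kv_idx = torch.arange(S).view(1, 1, S)
    h = torch.zeros((), dtype=torch.long)
    return mask_mod(b, h, q_idx, kv_idx)


B, S = document_ids.shape
dense = dense_from_mask_mod(document_causal, B, S)  # bool tensor, shape (B, S, S)
```

Because `document_ids` differs per row, the resulting mask differs per batch element, which is the batch-dependent situation this issue is about.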

import torch
from tqdm import tqdm
from torch.nn.attention.flex_attention import (
    BlockMask,
    _score_mod_signature,
    flex_attention,
    create_block_mask,
)
from torch import nn
import logging


torch._logging.set_logs(dynamo=logging.DEBUG)
torch._dynamo.config.verbose = True
torch.autograd.set_detect_anomaly(True)
torch._inductor.config.debug = True

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def batch_flip_causal(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (b % 2 == 0)


class SelfAttentionLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        n_head: int,
        dropout: float = 0.0,
        bias=False,
    ):
        super().__init__()
        assert (
            dim % n_head == 0
        ), f"dim must be divisible by n_head found: {dim} and {n_head}"

        # key, query, value projections for all heads, but in a batch
        self.qkv = nn.Linear(dim, 3 * dim, bias=bias)
        # output projection
        self.c_proj = nn.Linear(dim, dim, bias=bias)
        # regularization

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = n_head
        self.head_dim = dim // n_head

        self.n_embd = dim
        self.dropout = dropout

    def forward(
        self,
        x,
        score_mod: None | _score_mod_signature = None,
        block_mask: None | BlockMask = None,
    ):
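        B, T, C = x.size()  # batch size, sequence length, embedding dim

        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.n_head, self.head_dim)
        # (B, T, 3, n_head, head_dim) -> (3, B, n_head, T, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv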

        y = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))
        return y


@torch.no_grad()
def create_block_mask_from_dense(
    dense_mask: torch.Tensor, device, block_size: int = 128, compile: bool = False
) -> BlockMask:
    B, S = dense_mask.shape

    def mask_mod_partial(b, h, q_idx, kv_idx):
        return ~dense_mask[b, q_idx]

    return create_block_mask(
        B=B,
        BLOCK_SIZE=block_size,
        mask_mod=mask_mod_partial,
        H=None,
        Q_LEN=S,
        KV_LEN=S,
        _compile=compile,
        device=device,
    )


torch.set_float32_matmul_precision('high')
model = SelfAttentionLayer(
    dim = 512,
    n_head = 8,
    dropout = 0
).cuda()



compile_dynamic = True
compile_block_mask = True
sequence_len = 256
block_mask_type = "dense"



model.compile(mode="default", dynamic=compile_dynamic)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch_shape in tqdm(range(2,55)):
    
    match block_mask_type:
        case "causal":
            block_mask = create_block_mask(
                mask_mod=causal,
                B=None,
                H=None,
                Q_LEN=sequence_len,
                KV_LEN=sequence_len,
                device="cuda",
                _compile=compile_block_mask,
            )
        case "dense":        
            rand_mask = torch.randint(0,2,(batch_shape,sequence_len)).to("cuda").bool()
            block_mask = create_block_mask_from_dense(rand_mask,device="cuda",compile=compile_block_mask)
        case "batch_flip_causal":
            block_mask = create_block_mask(
                mask_mod=batch_flip_causal,
                B=batch_shape,
                H=None,
                Q_LEN=sequence_len,
                KV_LEN=sequence_len,
                device="cuda",
                _compile=compile_block_mask,
            )
        case _:
            raise ValueError("Invalid block_mask_type")
    x = torch.randn(batch_shape, sequence_len, 512).to("cuda")
    y = model(x,score_mod = None, block_mask = block_mask)
    loss = y.mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
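
A note on masks like `batch_flip_causal`: in Python, `%` binds tighter than `&`, which in turn binds tighter than `==`, so where the parentheses sit around the parity test changes the result. The plain-Python sketch below (function names are illustrative and not from the repro) shows the two parses disagreeing.

```python
def unparenthesized(b, q_idx, kv_idx):
    # Parsed as ((q_idx >= kv_idx) & (b % 2)) == 0
    return (q_idx >= kv_idx) & b % 2 == 0


def parenthesized(b, q_idx, kv_idx):
    # Causal AND even batch index
    return (q_idx >= kv_idx) & (b % 2 == 0)


# Collect every small (b, q_idx, kv_idx) triple where the two versions differ.
disagreements = [
    (b, q, kv)
    for b in range(2)
    for q in range(2)
    for kv in range(2)
    if unparenthesized(b, q, kv) != parenthesized(b, q, kv)
]
```

The two versions differ exactly on the non-causal entries (`q_idx < kv_idx`), where the unparenthesized form evaluates to `True` regardless of `b`.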

Error in dense case:

Traceback (most recent call last):
  File "/workspaces/methylation_prediction/test_dynamic.py", line 149, in <module>
    y = model(x,score_mod = None, block_mask = block_mask)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1734, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1278, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1073, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1525, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 586, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1354, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1425, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 476, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 662, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1370, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 571, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 879, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1953, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1879, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1907, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 2856, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/or/cor2irb7stdmyjvmeeltisryx6a624z2ymt6sp74zhhvbxrvrl4v.py", line 521, in <module>
    async_compile.wait(globals())
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 286, in wait
    scope[key] = result.result()
                 ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3321, in result
    result = self.future.result()
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
SubprocException: An exception occurred in a subprocess:

triton.compiler.errors.CompilationError: at 56:47:
        m = offs_m
        n = offs_n

    post_mod_scores = (qk)


    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        tmp0 = tl.load(in_ptr6 + (m) + (off_z)*s1) == 0
                                               ^
NameError('s1 is not defined')

The above exception was the direct cause of the following exception:

triton.compiler.errors.CompilationError: at 57:28:
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
            # it's on par or slightly faster than only applying to the last block in fwd.
            # However, we choose different strategy for bwd, where we only apply mod & mask
            # to the last block because it's faster a lot.
            acc, l_i, m_i = forward_block_mn(
                            ^

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_worker/subproc_pool.py", line 270, in do_job
    result = job()
             ^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 68, in _worker_compile_triton
    load_kernel().precompile(warm_cache_only=True)
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 244, in precompile
    compiled_binary, launcher = self._precompile_config(
                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 428, in _precompile_config
    triton.compile(*compile_args, **compile_kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 154:20:
    )
    V_block_ptr = tl.make_block_ptr(
        base=V,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(kv_start, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM),
        order=(1, 0)
    )
    offs_n = kv_start + tl.arange(0, BLOCK_N)

    acc, l_i, m_i = forward_inner(
                    

Versions

Collecting environment information...
PyTorch version: 2.5.0.dev20240912+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.3
Libc version: glibc-2.35

Python version: 3.11.10 | packaged by conda-forge | (main, Sep 10 2024, 11:01:28) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-118-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 160
On-line CPU(s) list: 0-159
Vendor ID: GenuineIntel
Model name: Intel Xeon Processor (Icelake)
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 40
Socket(s): 2
Stepping: 0
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 5 MiB (160 instances)
L1i cache: 5 MiB (160 instances)
L2 cache: 320 MiB (80 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-79
NUMA node1 CPU(s): 80-159
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.0.2
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.1.0+5fe38ffd73
[pip3] torch==2.5.0.dev20240912+cu124
[pip3] torch_scatter==2.1.2
[pip3] torchaudio==2.5.0.dev20240912+cu124
[pip3] torchelastic==0.2.2
[pip3] torcheval==0.0.7
[pip3] torchvision==0.20.0.dev20240912+cu124
[pip3] triton==3.0.0
[conda] numpy 2.0.2 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-triton 3.1.0+5fe38ffd73 pypi_0 pypi
[conda] torch 2.5.0.dev20240912+cu124 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240912+cu124 pypi_0 pypi
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torcheval 0.0.7 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240912+cu124 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng
