torch.prod cannot be used with cudagraphs #128396

@drhead

🐛 Describe the bug

import torch
from torch import nn
print(torch.__version__)
class ProdLayer(nn.Module):
    def __init__(self):
        super(ProdLayer, self).__init__()
        self.layer = nn.Linear(1024, 1024, device='cuda')

    def forward(self, x: torch.Tensor):
        x = self.layer(x)
        x = torch.prod(x, 1)
        return x

module = ProdLayer()

input = torch.randn((1024,1024), device='cuda')
module = torch.cuda.make_graphed_callables(module, (input,))

When I run the above code, I get this error:

{
	"name": "RuntimeError",
	"message": "CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File .../lib/python3.11/site-packages/torch/cuda/graphs.py:364, in make_graphed_callables(callables, sample_args, num_warmup_iters, allow_unused_input, pool)
    363 with torch.cuda.graph(bwd_graph, pool=mempool):
--> 364     grad_inputs = torch.autograd.grad(
    365         outputs=tuple(o for o in static_outputs if o.requires_grad),
    366         inputs=tuple(i for i in static_input_surface if i.requires_grad),
    367         grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
    368         only_inputs=True,
    369         allow_unused=allow_unused_input,
    370     )
    372 # Constructs a tuple suitable for returning from Graphed.backward:
    373 # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
    374 # I couldn't think of a slick one-liner for this pattern.

File .../lib/python3.11/site-packages/torch/autograd/__init__.py:412, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    411 else:
--> 412     result = _engine_run_backward(
    413         t_outputs,
    414         grad_outputs_,
    415         retain_graph,
    416         create_graph,
    417         inputs,
    418         allow_unused,
    419         accumulate_grad=False,
    420     )
    421 if materialize_grads:

File .../lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:

RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 18
     15 module = ProdLayer()
     17 input = torch.randn((1024,1024), device='cuda')
---> 18 module = torch.cuda.make_graphed_callables(
     19     module,
     20     (input,),
     21 )

File .../lib/python3.11/site-packages/torch/cuda/graphs.py:363, in make_graphed_callables(callables, sample_args, num_warmup_iters, allow_unused_input, pool)
    351 for static_input_surface, static_outputs, bwd_graph, module_params in zip(
    352     reversed(per_callable_static_input_surfaces),
    353     reversed(per_callable_static_outputs),
   (...)
    357     # For now, assumes all static_outputs require grad
    358     # assert all(o.requires_grad for o in static_outputs), \"Outputs of graphed callables must require grad.\"
    359     static_grad_outputs = tuple(
    360         torch.empty_like(o) if o.requires_grad else None for o in static_outputs
    361     )
--> 363     with torch.cuda.graph(bwd_graph, pool=mempool):
    364         grad_inputs = torch.autograd.grad(
    365             outputs=tuple(o for o in static_outputs if o.requires_grad),
    366             inputs=tuple(i for i in static_input_surface if i.requires_grad),
   (...)
    369             allow_unused=allow_unused_input,
    370         )
    372     # Constructs a tuple suitable for returning from Graphed.backward:
    373     # Pads out the actually-needed grads with Nones in gradient slots for inputs that don't require grad.
    374     # I couldn't think of a slick one-liner for this pattern.

File .../lib/python3.11/site-packages/torch/cuda/graphs.py:184, in graph.__exit__(self, exc_type, exc_value, traceback)
    183 def __exit__(self, exc_type, exc_value, traceback):
--> 184     self.cuda_graph.capture_end()
    185     self.stream_ctx.__exit__(exc_type, exc_value, traceback)

File .../lib/python3.11/site-packages/torch/cuda/graphs.py:82, in CUDAGraph.capture_end(self)
     73 def capture_end(self):
     74     r\"\"\"End CUDA graph capture on the current stream.
     75 
     76     After ``capture_end``, ``replay`` may be called on this instance.
   (...)
     80     which call ``capture_end`` internally.
     81     \"\"\"
---> 82     super().capture_end()

RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Additionally, running this within a no_grad context instead produces an error stating that an input was unused.
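For the inference-only case, a possible alternative is to capture the forward pass by hand with torch.cuda.CUDAGraph instead of going through make_graphed_callables, which always tries to capture a backward graph. The traceback above fails while capturing bwd_graph, which suggests the forward capture itself went through. The following is a minimal sketch under that assumption, not a confirmed fix; make_inference_runner is a hypothetical helper name, and it falls back to eager execution when no GPU is available.

```python
import torch
from torch import nn


class ProdLayer(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.layer = nn.Linear(1024, 1024, device=device)

    def forward(self, x):
        return torch.prod(self.layer(x), 1)


def make_inference_runner(module, sample):
    """Hypothetical helper (not a torch API): run `module` without autograd.

    On a CUDA tensor it captures the forward pass into a CUDA graph;
    otherwise it simply runs the module eagerly under no_grad.
    """
    if not sample.is_cuda:
        def run_eager(x):
            with torch.no_grad():
                return module(x)
        return run_eager

    static_input = sample.clone()

    # Warm up on a side stream before capturing.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side), torch.no_grad():
        for _ in range(3):
            module(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # Capture only the forward pass; no backward graph is recorded.
    graph = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(graph):
        static_output = module(static_input)

    def run_graphed(x):
        static_input.copy_(x)         # refill the captured input buffer
        graph.replay()                # rerun the captured kernels
        return static_output.clone()  # copy out of the static buffer
    return run_graphed


device = "cuda" if torch.cuda.is_available() else "cpu"
module = ProdLayer(device)
x = torch.randn(1024, 1024, device=device)
runner = make_inference_runner(module, x)
out = runner(x)
```

Inputs passed to the returned runner must have the same shape and dtype as the sample, since the graph replays against fixed buffers.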

It seems a bit odd that a very basic operation like this wouldn't be supported, so I am leaning towards this being a bug, but if it really just isn't supported it should have a clearer error message.
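Until this is resolved, one workaround to try for training is to graph only the linear layer and leave torch.prod outside the graphed callable, so that its backward runs eagerly and is never captured. This is an untested sketch based on the assumption that the backward of torch.prod is what breaks capture; the report itself does not confirm the cause. The forward function below is a hypothetical stand-in for ProdLayer.forward.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = nn.Linear(1024, 1024, device=device)
sample = torch.randn(1024, 1024, device=device)

# Graph only the linear layer; on machines without a GPU, skip graphing.
if device == "cuda":
    linear_fn = torch.cuda.make_graphed_callables(linear, (sample,))
else:
    linear_fn = linear


def forward(x):
    hidden = linear_fn(x)         # replayed from the CUDA graph on GPU
    return torch.prod(hidden, 1)  # eager, so its backward is not captured


out = forward(sample)
out.sum().backward()
```

The trade-off is that the product reduction and its backward are launched as ordinary kernels on every step, so only the linear layer benefits from graph replay.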

Versions

Collecting environment information...
PyTorch version: 2.3.1
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux trixie/sid (x86_64)
GCC version: (Debian 13.2.0-25) 13.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.38

Python version: 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.7.12-amd64-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 555.42.02
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               12
On-line CPU(s) list:                  0-11
Vendor ID:                            AuthenticAMD
Model name:                           AMD Ryzen 5 5600X 6-Core Processor
CPU family:                           25
Model:                                33
Thread(s) per core:                   2
Core(s) per socket:                   6
Socket(s):                            1
Stepping:                             2
Frequency boost:                      enabled
CPU(s) scaling MHz:                   83%
CPU max MHz:                          4650.2920
CPU min MHz:                          2200.0000
BogoMIPS:                             7386.38
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
L1d cache:                            192 KiB (6 instances)
L1i cache:                            192 KiB (6 instances)
L2 cache:                             3 MiB (6 instances)
L3 cache:                             32 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-11
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] pytorch-triton==2.3.0
[pip3] random-fourier-features-pytorch==1.0.1
[pip3] torch==2.3.1
[pip3] torch-tb-profiler==0.4.3
[pip3] torchaudio==2.3.1
[pip3] torchvision==0.18.1
[pip3] triton==2.3.1
[conda] blas                      1.0                         mkl    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] numpy                     1.26.3          py311h64a7726_0    conda-forge
[conda] pytorch                   2.3.1           py3.11_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pytorch-triton            2.3.0                    pypi_0    pypi
[conda] random-fourier-features-pytorch 1.0.1                    pypi_0    pypi
[conda] torch                     2.3.0+cu121              pypi_0    pypi
[conda] torch-tb-profiler         0.4.3                    pypi_0    pypi
[conda] torchaudio                2.3.0+cu121              pypi_0    pypi
[conda] torchtriton               2.3.1                     py311    pytorch
[conda] torchvision               0.18.0+cu121             pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi

cc @mcarilli @ezyang @eellison @peterbell10

Labels: module: cuda graphs, triaged
