torch.compiled model output gets overwritten despite tensor.detach() #104435

@HDCharles

🐛 Describe the bug

Related to this error message:

msg = (
    "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. "
    f"Stack trace: {stack_trace}. "
    "To prevent overwriting, clone the tensor outside of torch.compile() "
    "before running the model again."
)

In cases where you would get this error, replacing out = model(input) with out = model(input).detach() to try to fix it suppresses the error message without fixing the problem. Specifically, the value of out still changes if you run model(input).detach() again. You have to do model(input)+0 or something similar to actually fix the problem.
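A minimal eager-mode sketch of why .detach() alone would not be expected to protect a value, while .clone() or +0 does. It runs on CPU and involves no CUDA graphs; the tensor `buf` and the in-place `mul_` are illustrative stand-ins for an output buffer being rewritten by a later run, not the actual torch.compile mechanism.

```python
import torch

# Stand-in for an output buffer that a later run writes into again.
# Plain eager CPU tensors are used here; no CUDA graphs are involved.
buf = torch.tensor([1.0, 2.0])

detached = buf.detach()  # new tensor object, same underlying storage
cloned = buf.clone()     # fresh storage holding a copy of the values
added = buf + 0          # out-of-place op, so the result has its own storage

# Simulate a subsequent run overwriting the buffer in place.
buf.mul_(10)

shares_storage = detached.data_ptr() == buf.data_ptr()
print(shares_storage)     # True
print(detached.tolist())  # [10.0, 20.0]
print(cloned.tolist())    # [1.0, 2.0]
print(added.tolist())     # [1.0, 2.0]
```

The detached tensor follows the buffer because both point at the same memory, which is consistent with the behavior described above.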

At a high level, I think this bug is either
A) about tensor.detach() suppressing an error message without fixing the error, or
B) about model outputs getting overwritten despite tensor.detach(),
depending on whether B is expected or not.

Either the error message should not be suppressed, or the output value should behave as expected.

@eellison

Error logs

n/a

Minified repro

n/a

My own repro: try running with/without @torch.compile() and with/without .detach().
Running it as is should either raise the error message or give the same result as running without @torch.compile.

import torch

@torch.compile(mode='reduce-overhead')
def foo(x):
    return x * x * x

inp = torch.rand([2], device="cuda")
out = foo(inp).detach()
sum_val_1 = out + out  # computed from out before the second run
out2 = foo(inp).detach()
sum_val_2 = out + out  # same expression after the second run; should match sum_val_1
print(sum_val_1, sum_val_2, out2 + out2)
assert sum_val_1.sum() == sum_val_2.sum()
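For comparison, here is a sketch of the same check using .clone(), which is what the error message itself recommends. The helper name `output_is_stable` and the function `cube` are hypothetical, introduced only for illustration. Only the eager CPU baseline is executed; the compiled CUDA call is left as a comment because it has not been verified here.

```python
import torch

def cube(x):
    return x * x * x

def output_is_stable(fn, inp):
    """Return True if the first output is unchanged after fn runs again."""
    out = fn(inp).clone()  # copy taken outside the (possibly compiled) function
    before = out + out
    fn(inp)                # second run; a compiled fn may reuse its output buffer
    after = out + out
    return torch.equal(before, after)

# Eager CPU baseline, runnable without a GPU.
stable = output_is_stable(cube, torch.rand([2]))
print(stable)  # True

# On a CUDA machine, the compiled variant from the repro could be checked with:
# compiled = torch.compile(cube, mode="reduce-overhead")
# output_is_stable(compiled, torch.rand([2], device="cuda"))
```

Swapping .clone() for .detach() inside the helper gives the failing case reported in this issue when `fn` is the compiled function.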

Versions

Collecting environment information...
PyTorch version: 2.1.0a0+git1dba81f
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.26.1
Libc version: glibc-2.31

Python version: 3.9.5 (default, Jun 4 2021, 12:28:51) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-1019-aws-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
Stepping: 7
CPU MHz: 2500.000
BogoMIPS: 5000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==0.960
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] pytorch-triton==2.1.0+440fd1bf20
[pip3] torch==2.1.0a0+git1dba81f
[pip3] torchvision==0.16.0a0+e5bf7cf
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-include 2023.0.0 h06a4308_25399
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.1 pypi_0 pypi
[conda] pytorch-triton 2.1.0+440fd1bf20 pypi_0 pypi
[conda] torch 2.1.0a0+git1dba81f dev_0
[conda] torchvision 0.16.0a0+e5bf7cf dev_0

cc @mcarilli @ezyang @eellison @peterbell10 @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @wconstab

Labels

module: cuda graphs (Ability to capture and then replay streams of CUDA kernels)
oncall: pt2
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
