Skip to content

Inductor generates sync copy_() on .to(device, non_blocking=True) #136260

@IvanKobzarev

Description

@IvanKobzarev

๐Ÿ› Describe the bug

The original complaint comes from internal users of torch.compile:

import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)

Generates:

   [__graph_code] TRACED GRAPH
   [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
   [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
   [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
   [__graph_code]         l_x_ = L_x_
   [__graph_code]         
   [__graph_code]          # File: /home/ivankobzarev/oncall-to-non-blocking/test.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
   [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
   [__graph_code]         return (to,)
   [__graph_code]         
   [__graph_code] 
   [__graph_code] TRACED GRAPH
   [__graph_code]  ===== __compiled_fn_1 =====
   [__graph_code]  /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
   [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
   [__graph_code]         l_x_ = L_x_
   [__graph_code]         
   [__graph_code]          # File: /home/ivankobzarev/oncall-to-non-blocking/test.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
   [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
   [__graph_code]         return (to,)
   [__graph_code]         
   [__graph_code] 
   [__aot_graphs] TRACED GRAPH
   [__aot_graphs]  ===== Forward graph 0 =====
   [__aot_graphs]  /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
   [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
   [__aot_graphs]          # File: /home/ivankobzarev/oncall-to-non-blocking/test.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
   [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0));  arg0_1 = None
   [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32);  device_put = None
   [__aot_graphs]         return (convert_element_type,)
   [__aot_graphs]         
   [__aot_graphs] 
W0918 03:36:41.891000 3691669 torch/_inductor/utils.py:1441] [0/0] DeviceCopy in input program
   [__output_code] Output code: 
   [__output_code] # AOT ID: ['0_inference']
   [__output_code] from ctypes import c_void_p, c_long, c_int
   [__output_code] import torch
   [__output_code] import math
   [__output_code] import random
   [__output_code] import os
   [__output_code] import tempfile
   [__output_code] from math import inf, nan
   [__output_code] from torch._inductor.hooks import run_intermediate_hooks
   [__output_code] from torch._inductor.utils import maybe_profile
   [__output_code] from torch._inductor.codegen.memory_planning import _align as align
   [__output_code] from torch import device, empty_strided
   [__output_code] from torch._inductor.async_compile import AsyncCompile
   [__output_code] from torch._inductor.select_algorithm import extern_kernels
   [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
   [__output_code] 
   [__output_code] aten = torch.ops.aten
   [__output_code] inductor_ops = torch.ops.inductor
   [__output_code] _quantized = torch.ops._quantized
   [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
   [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
   [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
   [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
   [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
   [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
   [__output_code] async_compile = AsyncCompile()
   [__output_code] 
   [__output_code] 
   [__output_code] async_compile.wait(globals())
   [__output_code] del async_compile
   [__output_code] 
   [__output_code] def call(args):
   [__output_code]     arg0_1, = args
   [__output_code]     args.clear()
   [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
   [__output_code]     with torch.cuda._DeviceGuard(0):
   [__output_code]         torch.cuda.set_device(0)
   [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
   [__output_code]         buf0.copy_(arg0_1)
   [__output_code]         del arg0_1
   [__output_code]     return (buf0, )
   [__output_code] 
   [__output_code] 
   [__output_code] def benchmark_compiled_module(times=10, repeat=10):
   [__output_code]     from torch._dynamo.testing import rand_strided
   [__output_code]     from torch._inductor.utils import print_performance
   [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
   [__output_code]     fn = lambda: call([arg0_1])
   [__output_code]     return print_performance(fn, times=times, repeat=repeat)
   [__output_code] 
   [__output_code] 
   [__output_code] if __name__ == "__main__":
   [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
   [__output_code]     compiled_module_main('None', benchmark_compiled_module)
   [__output_code] 

The Dynamo graph and the AOT graph contain non_blocking=True, while Inductor generates copy_() with the default non_blocking=False, which results in a synchronous memory copy.

Error logs

No error logs.
Inductor generates copy_()

Minified repro

import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)

TORCH_LOGS="aot,graph_code,output_code"

Versions

Collecting environment information...
PyTorch version: 2.6.0a0+gite248c1d
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A
OS: CentOS Stream 9 (x86_64)
GCC version: (Anaconda gcc) 11.2.0
Clang version: 14.0.6
CMake version: version 3.29.5
Libc version: glibc-2.34
Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.12.0-0_fbk16_zion_7661_geb00762ce6d2-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA PG509-210
GPU 1: NVIDIA PG509-210
GPU 2: NVIDIA PG509-210
GPU 3: NVIDIA PG509-210
GPU 4: NVIDIA PG509-210
GPU 5: NVIDIA PG509-210
GPU 6: NVIDIA PG509-210
GPU 7: NVIDIA PG509-210
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.9.3
/usr/lib64/libcudnn.so.9.1.1
/usr/lib64/libcudnn_adv.so.9.1.1
/usr/lib64/libcudnn_adv_infer.so.8.9.3
/usr/lib64/libcudnn_adv_train.so.8.9.3
/usr/lib64/libcudnn_cnn.so.9.1.1
/usr/lib64/libcudnn_cnn_infer.so.8.9.3
/usr/lib64/libcudnn_cnn_train.so.8.9.3
/usr/lib64/libcudnn_engines_precompiled.so.9.1.1
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.1.1
/usr/lib64/libcudnn_graph.so.9.1.1
/usr/lib64/libcudnn_heuristic.so.9.1.1
/usr/lib64/libcudnn_ops.so.9.1.1
/usr/lib64/libcudnn_ops_infer.so.8.9.3
/usr/lib64/libcudnn_ops_train.so.8.9.3
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 4
Stepping: 11
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 1801.0000
CPU min MHz: 800.0000
BogoMIPS: 3600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 96 MiB (96 instances)
L3 cache: 132 MiB (4 instances)
NUMA node(s): 4
NUMA node0 CPU(s): 0-23,96-119
NUMA node1 CPU(s): 24-47,120-143
NUMA node2 CPU(s): 48-71,144-167
NUMA node3 CPU(s): 72-95,168-191
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] clip-anytorch==2.6.0
[pip3] CoCa-pytorch==0.1.0
[pip3] dalle2-pytorch==1.14.2
[pip3] ema-pytorch==0.5.0
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] onnx==1.16.1
[pip3] open-clip-torch==2.24.0
[pip3] optree==0.12.1
[pip3] pytorch-labs-segment-anything-fast==0.2
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.3.3
[pip3] torch==2.6.0a0+gite248c1d
[pip3] torch-fidelity==0.3.0
[pip3] torch_geometric==2.4.0
[pip3] torchao==0.3.1
[pip3] torchaudio==2.4.0a0+b829e93
[pip3] torchdata==0.7.1a0+958eeb0
[pip3] torchmetrics==1.0.3
[pip3] torchmultimodal==0.1.0b0
[pip3] torchrec==0.8.0a0+3866a49
[pip3] torchtext==0.17.0a0+09e2690
[pip3] torchvision==0.19.0a0+143d078
[pip3] vector-quantize-pytorch==1.14.26
[conda] bert-pytorch 0.0.1a4 dev_0
[conda] blas 1.0 mkl defaults
[conda] clip-anytorch 2.6.0 pypi_0 pypi
[conda] coca-pytorch 0.1.0 pypi_0 pypi
[conda] dalle2-pytorch 1.14.2 pypi_0 pypi
[conda] diffusers-torch 0.18.2 py310h2f386ee_0 defaults
[conda] ema-pytorch 0.5.0 pypi_0 pypi
[conda] functorch 1.14.0a0+b71aa0b pypi_0 pypi
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] magma-cuda116 2.6.1 1 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344 defaults
[conda] mkl-include 2023.1.0 h06a4308_46344 defaults
[conda] mkl-service 2.4.0 py310h5eee18b_1 defaults
[conda] mkl_fft 1.3.8 py310h5eee18b_0 defaults
[conda] mkl_random 1.2.4 py310hdb19cb5_0 defaults
[conda] numpy 1.26.0 pypi_0 pypi
[conda] numpy-base 1.26.4 py310hb5e798b_0 defaults
[conda] open-clip-torch 2.24.0 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch-labs-segment-anything-fast 0.2 pypi_0 pypi
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] pytorch-warmup 0.1.1 pypi_0 pypi
[conda] rotary-embedding-torch 0.3.3 pypi_0 pypi
[conda] torch 2.6.0a0+gite248c1d dev_0
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchao 0.3.1 dev_0
[conda] torchaudio 2.4.0a0+b829e93 dev_0
[conda] torchdata 0.7.1a0+958eeb0 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchmetrics 1.0.3 pypi_0 pypi
[conda] torchmultimodal 0.1.0b0 pypi_0 pypi
[conda] torchrec 0.8.0a0+3866a49 dev_0
[conda] torchtext 0.17.0a0+09e2690 dev_0
[conda] vector-quantize-pytorch 1.14.26 pypi_0 pypi

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire

Metadata

Metadata

Assignees

Labels

high priority · module: inductor · oncall: pt2 · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions