Description
🐛 Describe the bug
```python
import torch

def fn(x):
    result = torch.rrelu(x, 0.2, 0.8, training=True)
    return result

x = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)
compiled_fn = torch.compile(fn, backend="inductor")
res = compiled_fn(x)
```
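For comparison, the same function appears to run without error in eager mode, which suggests the failure is specific to the inductor/AOTAutograd path (a minimal eager check, not part of the original report):

```python
import torch

def fn(x):
    # rrelu in training mode samples a per-element negative slope
    # uniformly from [lower, upper] = [0.2, 0.8]
    return torch.rrelu(x, 0.2, 0.8, training=True)

x = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)
res = fn(x)  # eager mode: no error, result stays bfloat16
```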
Error logs
```
  File "/home/junlin/.pt24/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 415, in assert_functional_graph
    n.args[0] in placeholders
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: n=copy, n.args[0]=convert_element_type_1, placeholders={primals_1, tangents_1}, graph=graph():
    %primals_1 : [num_users=2] = placeholder[target=primals_1]
    %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
    %empty : [num_users=2] = call_function[target=torch.ops.aten.empty.memory_format](args = ([4, 4],), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cpu, pin_memory: False, memory_format: torch.contiguous_format})
    %convert_element_type : [num_users=4] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%primals_1, torch.float32), kwargs = {})
    %convert_element_type_1 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%empty, torch.float32), kwargs = {})
    %le : [num_users=2] = call_function[target=torch.ops.aten.le.Scalar](args = (%convert_element_type, 0), kwargs = {})
    %uniform : [num_users=2] = call_function[target=torch.ops.aten.uniform.default](args = (%convert_element_type, 0.2, 0.8), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type, %uniform), kwargs = {})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le, %mul, %convert_element_type), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (1,), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le, %uniform, %scalar_tensor), kwargs = {})
    %copy : [num_users=0] = call_function[target=torch.ops.aten.copy.default](args = (%convert_element_type_1, %where_1), kwargs = {})
    %convert_element_type_2 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%where, torch.bfloat16), kwargs = {})
    %rrelu_with_noise_backward : [num_users=1] = call_function[target=torch.ops.aten.rrelu_with_noise_backward.default](args = (%tangents_1, %primals_1, %empty, 0.2, 0.8, True, False), kwargs = {})
    return [convert_element_type_2, rrelu_with_noise_backward]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
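As the error message suggests, the exception can be suppressed so the call falls back to eager (a workaround sketch only; it hides the compile failure rather than fixing it):

```python
import torch
import torch._dynamo

def fn(x):
    return torch.rrelu(x, 0.2, 0.8, training=True)

# Fall back to eager instead of raising on backend compile errors
torch._dynamo.config.suppress_errors = True

x = torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True)
compiled_fn = torch.compile(fn, backend="inductor")
res = compiled_fn(x)  # runs (via eager fallback if inductor fails)
```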
Minified repro
No response
Versions
Collecting environment information...
PyTorch version: 2.4.0a0+git6510557
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.5 (ssh://gerrit.habana-labs.com:29418/tpc_llvm10 48006f04b6800fa1e655ee09c6a8510b3f9f5d4f)
CMake version: version 3.30.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 5220R CPU @ 2.20GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 6
Socket(s): 2
Stepping: 0
BogoMIPS: 4389.68
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xsaves arat pku ospke md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 12 MiB (12 instances)
L3 cache: 71.5 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX flush not necessary, SMT disabled
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS; IBPB conditional; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] habana-torch-dataloader==1.18.0.204
[pip3] habana-torch-plugin==1.18.0.204
[pip3] numpy==1.23.1
[pip3] torch==2.4.0a0+git6510557
[pip3] torch_tb_profiler==0.4.0
[pip3] torchaudio==2.4.0a0+69d4077
[pip3] torchdata==0.7.1+5e6f7b7
[pip3] torchtext==0.18.0a0+9bed85d
[pip3] torchvision==0.19.0a0+48b1edf
[conda] Could not collect