Closed
Labels
module: aotinductor (aot inductor)
module: crash (Problem manifests as a hard crash, as opposed to a RuntimeError)
module: cuda (Related to torch.cuda, and CUDA support in general)
oncall: pt2
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
import torch
import torch._inductor
# Taken from test_config_option_dont_assume_alignment_cuda
class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + x.cos()
N = 64 * 64 * 64 + 64
dtype = torch.float32
arg = torch.randn(N, dtype=dtype, device='cuda')
args = (arg,)
m_arg = torch.zeros(N + 1, dtype=dtype, device='cuda')
m_arg = m_arg[1:]
m_arg.copy_(arg)
m_args = (m_arg,)
fn = M()
opt_fn = torch.compile()(fn)
print(opt_fn(*args))
# This triggers IMA if you stubbed out copy_misaligned_inputs
#print(opt_fn(*m_args))
fn = torch._inductor.aoti_compile_and_package(torch.export.export(M(), args))
print(fn)
model = torch._C._aoti.AOTIModelPackageLoader(fn, "model")
with torch.profiler.profile() as p:
    r = model.run(m_args)
print(r)
# Expect to see copy_ here
print(p.key_averages().table(sort_by="cpu_time_total", row_limit=10))
This fails with:
(/home/ezyang/local/a/pytorch-env) [[email protected] ~/local/a/pytorch (d016d437)]$ python n.py
tensor([ 1.1110, 0.6315, -1.4111, ..., 1.2410, 1.3356, 1.3179],
device='cuda:0')
/tmp/torchinductor_ezyang/csyafcbbeawpmvhukquhoe3txwcojexzz4zz25nqzcfuqrlyoayd/ckdvcj5r7cqv6rf52nbdrsecjighl2hqo5xdy52opjbqhjaae2ao.pt2
Traceback (most recent call last):
  File "/data/users/ezyang/a/pytorch/n.py", line 32, in <module>
    with torch.profiler.profile() as p:
  File "/data/users/ezyang/a/pytorch/torch/profiler/profiler.py", line 793, in __exit__
    self.stop()
  File "/data/users/ezyang/a/pytorch/torch/profiler/profiler.py", line 809, in stop
    self._transit_action(self.current_action, None)
  File "/data/users/ezyang/a/pytorch/torch/profiler/profiler.py", line 852, in _transit_action
    action()
  File "/data/users/ezyang/a/pytorch/torch/profiler/profiler.py", line 239, in stop_trace
    self.profiler.__exit__(None, None, None)
  File "/data/users/ezyang/a/pytorch/torch/autograd/profiler.py", line 369, in __exit__
    device_module.synchronize()
  File "/data/users/ezyang/a/pytorch/torch/cuda/__init__.py", line 967, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Failed to destroy CUDA event in AOTInductor model: misaligned address
[W1202 10:07:36.226747024 record_function.cpp:253] Requested callback is not found
Segmentation fault (core dumped)
Whether this triggers is sensitive to the input size: I originally tried N = 64 * 64 + 64, which triggered an illegal memory access (IMA) with torch.compile but not with AOTInductor.
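For readers unfamiliar with why `m_arg[1:]` is the interesting input: slicing off the first element of a float32 tensor moves the start of the data 4 bytes past the start of the allocation. The sketch below only illustrates that arithmetic. It assumes a 16-byte alignment requirement and uses a made-up base address; neither value comes from this report, and the helper names are hypothetical rather than PyTorch APIs.

```python
# Illustrative only: ALIGNMENT and the base address are assumptions,
# not values taken from the report or read from a real tensor.
ALIGNMENT = 16          # assumed alignment requirement, in bytes
FLOAT32_ITEMSIZE = 4    # size of one torch.float32 element, in bytes


def view_address(base_address: int, element_offset: int, itemsize: int) -> int:
    """Address of the first element of a view that starts `element_offset` elements in."""
    return base_address + element_offset * itemsize


def is_aligned(address: int, alignment: int = ALIGNMENT) -> bool:
    """True if `address` is a multiple of `alignment`."""
    return address % alignment == 0


# Hypothetical allocation start, chosen to be a multiple of 16.
base = 0x7F0000000000

# Equivalent of `m_arg = m_arg[1:]` on a float32 buffer: skip one element.
sliced = view_address(base, 1, FLOAT32_ITEMSIZE)

print(is_aligned(base))
print(is_aligned(sliced))
print(sliced % ALIGNMENT)
```

On a real tensor the same check can be made with `m_arg.data_ptr() % 16`, again taking 16 bytes as an assumed requirement.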
Versions
main
cc @ptrblck @msaroufim @eqy @chauhang @penguinwu @desertfire @chenyang78