Closed
Labels
oncall: pt2, triaged
Description
🐛 Describe the bug
Creating a tensor with torch.tensor inside a function wrapped in torch.compile works up to the 2.1.0.dev20230303+cpu nightly build, but fails starting with the 2.1.0.dev20230304+cpu nightly build.
The repro code and error log are given below.
Error logs
Traceback (most recent call last):
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
transformations(instructions, code_options)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in run
super().run()
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 619, in run
and self.step()
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 583, in step
getattr(self, inst.opname)(inst)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 349, in wrapper
return inner_fn(self, inst)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1063, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 517, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/variables/torch.py", line 562, in call_function
if isinstance(args[0], ListVariable) and check_any_unspec(args[0]):
IndexError: list index out of range
from user code:
File "tensor_compile.py", line 4, in fn
res = torch.tensor(data=[[1., -1.]])
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "tensor_compile.py", line 9, in <module>
res_fwd = fn()
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 368, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/home/jthakur/.pt_2_0/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 394, in _compile
raise InternalTorchDynamoError() from e
torch._dynamo.exc.InternalTorchDynamoError
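The innermost frame of the traceback reads `args[0]` while the user code passes `data` as a keyword argument, which suggests the positional argument list was empty at that point. The following is a minimal pure-Python sketch of that failure pattern, not Dynamo's actual code. The names `call_function_unguarded` and `call_function_guarded` are hypothetical, and the guarded variant is only one possible way to handle the keyword case, not necessarily the fix applied upstream.

```python
def call_function_unguarded(args, kwargs):
    """Hypothetical handler that reads the first positional argument unconditionally."""
    # Same shape as the failing check: args[0] is indexed without testing len(args).
    if isinstance(args[0], list):
        return "list passed positionally"
    return "other"


def call_function_guarded(args, kwargs):
    """Hypothetical variant that falls back to the `data` keyword when args is empty."""
    first = args[0] if args else kwargs.get("data")
    if isinstance(first, list):
        return "list found"
    return "other"


data = [[1.0, -1.0]]

# Positional call: args[0] exists, so the check succeeds.
print(call_function_unguarded([data], {}))

# Keyword-only call: args is empty, so args[0] raises IndexError.
try:
    call_function_unguarded([], {"data": data})
except IndexError as exc:
    print("IndexError:", exc)

# The guarded variant looks in kwargs instead of indexing an empty list.
print(call_function_guarded([], {"data": data}))
```

If this reading is right, passing the list positionally in the repro (`torch.tensor([[1., -1.]])`) would presumably avoid the failing check, though that is not verified here.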
Minified repro
import torch

def fn():
    res = torch.tensor(data=[[1., -1.]])
    return res

if __name__ == "__main__":
    fn = torch.compile(fn)
    res_fwd = fn()
    print(res_fwd)
Versions
Name: torch
Version: 2.1.0.dev20230304+cpu
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /home/jthakur/.pt_2_0/lib/python3.8/site-packages
Requires: filelock, jinja2, networkx, sympy, typing-extensions
Required-by: torchaudio, torchvision