Closed
Labels: oncall: jit
Description
🐛 Bug
See #20335 for context; this issue seems related.
To Reproduce
Steps to reproduce the behavior:
Run the following script:
import torch as th

class Mod(th.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return th.cat(2*[x], dim=0)
        #return th.cat((x, x), dim=0) <-- Unlike before, this will still cause the load below to fail.

class ScriptMod(th.jit.ScriptModule):
    def __init__(self, mod):
        super().__init__()
        x = th.zeros(1, 3)
        mod_fn = lambda : mod(x)
        self.mod = th.jit.trace(mod_fn, tuple())

    @th.jit.script_method
    def forward(self):
        return self.mod()

if __name__ == "__main__":
    with th.no_grad():
        cm = ScriptMod(Mod())
        cm.save("mod.ptj")
        cm = th.jit.load("mod.ptj")  # <-- This will fail with the same error as in the original repro.

Output:
Traceback (most recent call last):
  File "test.py", line 26, in <module>
    cm = th.jit.load("mod.ptj") # <---this will fail with the same error as in the original repro above.
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/__init__.py", line 151, in load
    torch._C.import_ir_module(module_lookup, f, map_location, _extra_files)
RuntimeError:
Arguments for call are not valid.
The following operator variants are available:

  aten::cat(Tensor[] tensors, int dim=<default>) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Use torch.jit.annotate(List[my_type], []) to create an empty list of another type.

  aten::cat(Tensor[] tensors, int dim=<default>, *, Tensor out) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Use torch.jit.annotate(List[my_type], []) to create an empty list of another type.

The original call is:
at code/mod.py:3:8
op_version_set = 1
def forward(self) -> Tensor:
  _0 = torch.cat([CONSTANTS.c0, CONSTANTS.c0], 0)
       ~~~~~~~~~ <--- HERE
  return torch.sum(_0, dtype=None)
Compiled from code test.py(9): forward
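A possible workaround, offered as an untested sketch rather than a confirmed fix: the error points at `CONSTANTS.c0`, which suggests the tensor captured by the zero-argument lambda was baked into the traced graph as a constant and the reloaded `List[Tensor]` of those constants is what fails to type-check. Assuming that is the trigger, tracing `Mod` with `x` as an explicit example input keeps the tensor out of the serialized constants. The sketch also swaps the older `ScriptModule`/`script_method` API for `torch.jit.script` on a plain `nn.Module`, and saves to an in-memory `io.BytesIO` buffer instead of `mod.ptj`. `Wrapper` and `round_trip` are hypothetical names, and the sketch has not been checked against the release that produced the error above.

```python
import io

import torch


class Mod(torch.nn.Module):
    def forward(self, x):
        # Same op as the repro: concatenate two references to the input.
        return torch.cat(2 * [x], dim=0)


class Wrapper(torch.nn.Module):
    # Hypothetical stand-in for ScriptMod: takes x as a real input instead
    # of closing over it, so the traced graph has no captured tensor.
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x)


def round_trip(module):
    # Save and reload through an in-memory buffer rather than a file on disk.
    buf = io.BytesIO()
    torch.jit.save(module, buf)
    buf.seek(0)
    return torch.jit.load(buf)


x = torch.zeros(1, 3)
traced = torch.jit.trace(Mod(), (x,))      # x is a graph input, not a constant
scripted = torch.jit.script(Wrapper(traced))
loaded = round_trip(scripted)
out = loaded(x)
```

If the constant-capture assumption is right, `loaded` should also accept inputs with a different first dimension, since `x` is now a graph input rather than a stored tensor.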