
cat() call in ScriptModule w/constant arguments causes loading of saved module to fail #22809

@HapeMask

🐛 Bug

This seems related to #20335; see that issue for context.

To Reproduce

Steps to reproduce the behavior:
Run the following script:

import torch as th

class Mod(th.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return th.cat(2*[x], dim=0)
        #return th.cat((x, x), dim=0)  <-- Unlike in #20335, this variant also makes the load below fail.

class ScriptMod(th.jit.ScriptModule):
    def __init__(self, mod):
        super().__init__()
        x = th.zeros(1, 3)
        mod_fn = lambda : mod(x)
        self.mod = th.jit.trace(mod_fn, tuple())

    @th.jit.script_method
    def forward(self):
        return self.mod()

if __name__ == "__main__":
    with th.no_grad():
        cm = ScriptMod(Mod())
        cm.save("mod.ptj")
        cm = th.jit.load("mod.ptj") # <-- This will fail with the same error as in the original repro.

Output:

Traceback (most recent call last):
  File "test.py", line 26, in <module>
    cm = th.jit.load("mod.ptj") # <---this will fail with the same error as in the original repro above.
  File "/mnt/home/gbschwartz/anaconda/envs/py3_newpytorch_cuda10/lib/python3.7/site-packages/torch/jit/__init__.py", line 151, in load
    torch._C.import_ir_module(module_lookup, f, map_location, _extra_files)
RuntimeError: 
Arguments for call are not valid.
The following operator variants are available:
  
  aten::cat(Tensor[] tensors, int dim=<default>) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Use torch.jit.annotate(List[my_type], []) to create an empty list of another type.
  
  aten::cat(Tensor[] tensors, int dim=<default>, *, Tensor out) -> Tensor:
  Expected a value of type 'List[Tensor]' for argument 'tensors' but instead found type 'List[Tensor]'.
  Empty lists default to List[Tensor]. Use torch.jit.annotate(List[my_type], []) to create an empty list of another type.

The original call is:
at code/mod.py:3:8
op_version_set = 1
def forward(self) -> Tensor:
  _0 = torch.cat([CONSTANTS.c0, CONSTANTS.c0], 0)
       ~~~~~~~~~ <--- HERE
  return torch.sum(_0, dtype=None)
Compiled from code test.py(9): forward
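A possible workaround, not verified against the PyTorch version that produced the traceback above: the generated code refers to `CONSTANTS.c0` because `x` is captured by the lambda passed to `th.jit.trace`, so the tracer records it as a constant rather than as an input. The sketch below assumes that tracing `Mod` directly with `x` as an example input keeps the `cat` arguments as graph inputs and avoids the constant list that fails to load. It also saves to an in-memory `io.BytesIO` buffer instead of `mod.ptj`; both changes are assumptions, not part of the original repro.

```python
import io

import torch as th


class Mod(th.nn.Module):
    def forward(self, x):
        return th.cat(2 * [x], dim=0)


class ScriptMod(th.jit.ScriptModule):
    def __init__(self, mod):
        super().__init__()
        # Trace with x as a real graph input instead of a tensor captured
        # in a closure, so the tracer does not bake it in as a constant.
        self.mod = th.jit.trace(mod, (th.zeros(1, 3),))

    @th.jit.script_method
    def forward(self, x):
        return self.mod(x)


with th.no_grad():
    cm = ScriptMod(Mod())
    buf = io.BytesIO()      # in-memory file instead of "mod.ptj"
    th.jit.save(cm, buf)    # th.jit.save accepts file-like objects
    buf.seek(0)
    loaded = th.jit.load(buf)
```

Note that `forward` now takes `x` explicitly, so callers must pass the input, e.g. `loaded(th.zeros(1, 3))`, which should return a tensor of shape `(2, 3)`. Whether this also sidesteps the bug on the affected release has not been checked.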

Labels: oncall: jit