
[JIT] Shape inference broken for torch.arange under JIT tracer #11829

@neerajprad


Issue description

When a function is traced with torch.jit.trace, torch.arange fails to generalize to other input shapes if its end argument depends on the input size. The traced function keeps returning a range whose length is the one seen at trace time. cc @apaszke, @zou3519, @soumith. Since torch.arange is used by a few distribution methods, those methods may return incorrect results when JIT-traced.

Code example

The following fails to generalize to other tensor sizes:

import torch

def fn(x):
    return torch.arange(x.shape[0])

# Trace with a length-3 input, then call with a length-5 input.
compiled = torch.jit.trace(fn, torch.ones(3))
print(compiled(torch.ones(5)))  # tensor([0, 1, 2]) on the reported build: wrong, should be tensor([0, 1, 2, 3, 4])
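To see why a trace like this cannot generalize, consider a deliberately simplified model of a tracer. The sketch below is hypothetical (make_trace and the arange callback are illustrative names, not PyTorch internals): it runs the function once on an example input, records each operation together with its arguments, and later replays only that recording. Because len(x) is evaluated to a plain integer before arange receives it, the recording can only hold the value 3, which reproduces the tensor([0, 1, 2]) result above.

```python
# Toy model of a tracing JIT (hypothetical; NOT PyTorch's implementation).


def make_trace(fn, example):
    """Run fn once on example and return a function that replays the recording."""
    ops = []  # recorded operations as (op_name, constant_argument) pairs

    def arange(end):
        # 'end' arrives as a plain int, so only its value can be recorded.
        ops.append(("arange", end))
        return list(range(end))

    fn(example, arange)  # the single tracing run

    def replay(x):
        # Replays the recorded ops; nothing here reads the new input's length.
        out = None
        for name, arg in ops:
            if name == "arange":
                out = list(range(arg))
        return out

    return replay


def fn(x, arange):
    return arange(len(x))  # len(x) is frozen to 3 during tracing


compiled = make_trace(fn, [1.0, 1.0, 1.0])
print(compiled([1.0] * 5))  # [0, 1, 2], not [0, 1, 2, 3, 4]
```

In this model, the only way to generalize is for the size read itself to be recorded as an operation that replay re-evaluates on the new input, rather than handing arange a finished integer.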

@fritzo suggested the following workaround in the meantime, which builds the range as a cumulative sum over a ones tensor of the input's length:

import torch

def fn_alt(x):
    # cumsum over n ones gives [1, 2, ..., n]; subtracting 1 gives [0, 1, ..., n - 1]
    return torch.cumsum(torch.ones(x.shape[0]), 0, dtype=torch.long) - 1

compiled = torch.jit.trace(fn_alt, torch.ones(3))
print(compiled(torch.ones(5)))  # tensor([0, 1, 2, 3, 4])
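The workaround rests on a simple identity: the cumulative sum of n ones is [1, 2, ..., n], so subtracting 1 gives [0, 1, ..., n - 1], which is exactly arange(n). Below is a plain-Python check of that identity, with itertools.accumulate standing in for torch.cumsum (an illustrative substitution; arange_via_cumsum is a hypothetical helper and the check does not exercise the JIT tracer at all):

```python
from itertools import accumulate


def arange_via_cumsum(n):
    # Running sum of n ones is [1, 2, ..., n]; shift down by one to start at 0.
    return [s - 1 for s in accumulate([1] * n)]


# Compare against range(n) for small sizes, including the empty case n = 0.
for n in range(10):
    assert arange_via_cumsum(n) == list(range(n))

print(arange_via_cumsum(5))  # [0, 1, 2, 3, 4]
```

This only confirms that the two expressions agree when run eagerly; whether a given expression survives tracing still has to be checked by tracing at one size and calling at another.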

System Info

Collecting environment information...
PyTorch version: 0.5.0a0+6660a12
Is debug build: Yes
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy (1.15.0)
[pip] torch (0.5.0a0+6c3792b, /Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages)
[pip] torchfile (0.1.0)
[pip] torchvision (0.2.1)
[conda] torch                     0.5.0a0+2431eac           <pip>
[conda] torch                     0.5.0a0+6c3792b           <pip>
[conda] torch                     0.5.0a0+6660a12           <pip>
[conda] torch                     0.5.0a0+35d52db           <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>

Labels: oncall: jit
