[JIT] torch.einsum not supported by JIT tracer #11157

@neerajprad

Issue description

We use torch.einsum in many Pyro models (e.g. HMMs), and the JIT tracer does not appear to support it yet. cc @t-vi, @apaszke, @fritzo.

Code example

The following example:

import torch


@torch.jit.trace(torch.ones(2, 3), torch.ones(2, 3))
def fn(x, y):
    return torch.einsum('ab,ab->b', [x, y])


# Tracing happens at decoration time, so the error is raised before this call.
fn(torch.ones(2, 3), torch.ones(2, 3))

throws an error on a recent PyTorch commit:

Traceback (most recent call last):
  File "examples/ein.py", line 4, in <module>
    @torch.jit.trace(torch.ones(2, 3), torch.ones(2, 3))
  File "/Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages/torch/jit/__init__.py", line 290, in wrapper
    module._create_method_from_trace('forward', func, tuple(args))
  File "examples/ein.py", line 6, in fn
    return torch.einsum('ab,ab->b', [x, y])
  File "/Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages/torch/functional.py", line 239, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.
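
Until the tracer handles einsum, one possible stopgap for this particular contraction is to spell it out with elementwise ops, since 'ab,ab->b' is just an elementwise product summed over the first index. The sketch below is only an illustration: the names einsum_reference and einsum_workaround are made up here, and it assumes the functional form torch.jit.trace(fn, example_inputs) from later PyTorch releases rather than the decorator form used in the report.

```python
import torch


def einsum_reference(x, y):
    # The contraction from the report, evaluated eagerly (no tracing).
    return torch.einsum('ab,ab->b', [x, y])


def einsum_workaround(x, y):
    # 'ab,ab->b' multiplies elementwise, then sums over the first index `a`.
    return (x * y).sum(dim=0)


x, y = torch.ones(2, 3), torch.ones(2, 3)

# Assumption: functional tracing API from later PyTorch releases.
traced = torch.jit.trace(einsum_workaround, (x, y))

result = traced(x, y)
expected = einsum_reference(x, y)
print(torch.allclose(result, expected))
```

This only covers the single equation in the example. More general einsum strings in the Pyro models would each need their own rewrite, so it is not a substitute for tracer support.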

System Info

  $ python collect_env.py
Collecting environment information...
PyTorch version: 0.5.0a0+72f91b1
Is debug build: Yes
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.12.0

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy (1.15.0)
[pip] torch (0.5.0a0+e449a27)
[pip] torchfile (0.1.0)
[pip] torchvision (0.2.1)
[conda] torch                     0.5.0a0+e449a27           <pip>
[conda] torch                     0.5.0a0+cb32e38           <pip>
[conda] torch                     0.5.0a0+c98d748           <pip>
[conda] torch                     0.5.0a0+64a6003           <pip>
[conda] torch                     0.5.0a0+72f91b1           <pip>
[conda] torch                     0.5.0a0+6456b94           <pip>
[conda] torch                     0.5.0a0+cd53b78           <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>

Metadata

Labels

oncall: jit
