Closed
Labels
oncall: jit
Description
Issue description
We are using the torch.einsum operation in many Pyro models (e.g. HMMs), and it appears that the JIT tracer does not currently support the einsum operation. cc @t-vi, @apaszke, @fritzo.
Code example
The following example:

```python
import torch


@torch.jit.trace(torch.ones(2, 3), torch.ones(2, 3))
def fn(x, y):
    return torch.einsum('ab,ab->b', [x, y])

fn(torch.ones(2, 3), torch.ones(2, 3))
```

throws an error on a recent PyTorch commit:
```
Traceback (most recent call last):
  File "examples/ein.py", line 4, in <module>
    @torch.jit.trace(torch.ones(2, 3), torch.ones(2, 3))
  File "/Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages/torch/jit/__init__.py", line 290, in wrapper
    module._create_method_from_trace('forward', func, tuple(args))
  File "examples/ein.py", line 6, in fn
    return torch.einsum('ab,ab->b', [x, y])
  File "/Users/npradhan/miniconda2/envs/pytorch-master/lib/python3.6/site-packages/torch/functional.py", line 239, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Found an unsupported argument type in the JIT tracer. File a bug report.
```
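For readers less familiar with the equation string in the example: `'ab,ab->b'` multiplies `x` and `y` elementwise and sums over the index `a`, which is absent from the output, so the result has one entry per column `b`. The sketch below spells this out in plain Python. `einsum_ab_ab_to_b` is a hypothetical helper written only for illustration (it is not part of PyTorch), and it assumes both inputs are rectangular lists of lists with the same shape.

```python
from typing import List

Matrix = List[List[float]]


def einsum_ab_ab_to_b(x: Matrix, y: Matrix) -> List[float]:
    """Hypothetical reference for torch.einsum('ab,ab->b', [x, y]).

    Multiplies x and y elementwise, then sums over index `a` (the rows).
    """
    num_cols = len(x[0]) if x else 0
    # Both inputs must have the same (a, b) shape, with no ragged rows.
    if len(x) != len(y) or any(len(row) != num_cols for row in x + y):
        raise ValueError("x and y must have the same rectangular shape (a, b)")
    out = [0.0] * num_cols
    for row_x, row_y in zip(x, y):      # index a: summed out
        for b in range(num_cols):       # index b: kept in the output
            out[b] += row_x[b] * row_y[b]
    return out


# Same shapes as in the report: two 2x3 inputs filled with ones.
ones = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
print(einsum_ab_ab_to_b(ones, ones))  # [2.0, 2.0, 2.0]
```

In PyTorch the same contraction can be written as `(x * y).sum(0)`. Because that form passes only tensors rather than a list of tensors, it might serve as a tracing workaround, but that is an untested assumption here.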
System Info
$ python collect_env.py
Collecting environment information...
PyTorch version: 0.5.0a0+72f91b1
Is debug build: Yes
CUDA used to build PyTorch: None
OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip] numpy (1.15.0)
[pip] torch (0.5.0a0+e449a27)
[pip] torchfile (0.1.0)
[pip] torchvision (0.2.1)
[conda] torch 0.5.0a0+e449a27 <pip>
[conda] torch 0.5.0a0+cb32e38 <pip>
[conda] torch 0.5.0a0+c98d748 <pip>
[conda] torch 0.5.0a0+64a6003 <pip>
[conda] torch 0.5.0a0+72f91b1 <pip>
[conda] torch 0.5.0a0+6456b94 <pip>
[conda] torch 0.5.0a0+cd53b78 <pip>
[conda] torchfile 0.1.0 <pip>
[conda] torchvision 0.2.1 <pip>