Description
🐛 Bug
The ONNX Loop produced by exporting a ScriptModule with torch.onnx._export does not conform to the ONNX Loop spec: the second input to Loop, the condition variable cond, should be tensor(bool), but the exported graph contains tensor(int64).
To Reproduce
Run the following script (environment details below):
from __future__ import print_function
import torch
import onnx
from onnx.onnx_pb2 import TensorProto


class SimpleModel(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        #self.xx = torch.zeros(2, 2)

    @torch.jit.script_method
    def forward(self, num_iter : int):
        x = torch.ones([2, 2], dtype=torch.float32)
        y = torch.ones(2, 2, dtype=torch.float32)
        v = torch.ones(2, 1, dtype=torch.float32)
        for i in range(num_iter):
            v = x * v
        return x, v


model = SimpleModel()
model_onnx = torch.onnx._export(model, torch.tensor(5), "simple_loop.onnx",
                                verbose=True,
                                operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
                                example_outputs=(torch.zeros(2,2), torch.zeros(2,1)))
prog = onnx.load("simple_loop.onnx")
print("%1 tensor type is int64? ", prog.graph.node[0].attribute[0].t.data_type
      == TensorProto.INT64)
The script produces the following ONNX program:
graph(%num_iter : Long()) {
  %1 : Long() = onnx::Constant[value={1}]()
  %2 : Float(2, 2) = onnx::Constant[value= 1  1  1  1 [ CPUFloatType{2,2} ]]()
  %3 : Float(2, 1) = onnx::Constant[value= 1  1 [ CPUFloatType{2,1} ]]()
  %4 : Tensor = onnx::Loop(%num_iter, %1, %3)
    block0(%i : int, %cond : Tensor, %7 : Tensor) {
      %8 : Tensor = onnx::Mul(%2, %7)
      -> (%1, %8)
    }
  return (%2, %4);
}
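%1 is a tensor of type TensorProto.INT64 (see the output of the print statement). However, %1 is passed as the second input to onnx::Loop (the cond input, bound to %cond in block0) and is also returned as the condition output of block0; the ONNX spec requires both to be tensor(bool).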
Environment
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: None
OS: Mac OSX 10.14.1
GCC version: Could not collect
CMake version: version 3.13.4
Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Versions of relevant libraries:
[pip3] numpy==1.16.1
[pip3] torch==1.0.1.post2
[conda] Could not collect