
JIT graph to ONNX Loop cond variable should be tensor(bool) #17531

@heydavid525

Description

🐛 Bug

The ONNX Loop produced by a ScriptModule through onnx._export does not conform to the ONNX Loop spec: the second input to ONNX Loop, the condition variable cond, should be tensor(bool), but I got tensor(int64).
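
For reference, here is a minimal sketch of how the spec types a Loop and its body, built with onnx.helper; the graph and value names (loop_body, cond_in, v_in, trip_count, and so on) are illustrative only and not taken from the exported model. The second body input and the first body output, as well as the Loop node's second input, are tensor(bool):

from onnx import TensorProto, helper

# Illustrative loop body: the second input (cond_in) and the first output
# (cond_out) are declared as scalar tensor(bool), as the Loop spec requires.
body = helper.make_graph(
    nodes=[
        helper.make_node("Identity", ["cond_in"], ["cond_out"]),
        helper.make_node("Identity", ["v_in"], ["v_out"]),
    ],
    name="loop_body",
    inputs=[
        helper.make_tensor_value_info("i", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("v_in", TensorProto.FLOAT, [2, 1]),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("v_out", TensorProto.FLOAT, [2, 1]),
    ],
)

# The Loop node's second input ("cond") must likewise be tensor(bool).
loop_node = helper.make_node(
    "Loop", ["trip_count", "cond", "v_init"], ["v_final"], body=body)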

To Reproduce

Run the following script (environment details below):

from __future__ import print_function
import torch
import onnx
from onnx import TensorProto

class SimpleModel(torch.jit.ScriptModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        #self.xx = torch.zeros(2, 2)

    @torch.jit.script_method
    def forward(self, num_iter: int):
        x = torch.ones([2, 2], dtype=torch.float32)
        y = torch.ones(2, 2, dtype=torch.float32)
        v = torch.ones(2, 1, dtype=torch.float32)
        for i in range(num_iter):
            v = x * v
        return x, v

model = SimpleModel()

model_onnx = torch.onnx._export(model, torch.tensor(5), "simple_loop.onnx",
        verbose=True,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
        example_outputs=(torch.zeros(2,2), torch.zeros(2,1)))

prog = onnx.load("simple_loop.onnx")
print("%1 tensor type is int64? ", prog.graph.node[0].attribute[0].t.data_type
        == TensorProto.INT64)

The export (verbose=True) prints the following graph:

graph(%num_iter : Long()) {
  %1 : Long() = onnx::Constant[value={1}]()
  %2 : Float(2, 2) = onnx::Constant[value= 1  1  1  1 [ CPUFloatType{2,2} ]]()
  %3 : Float(2, 1) = onnx::Constant[value= 1  1 [ CPUFloatType{2,1} ]]()
  %4 : Tensor = onnx::Loop(%num_iter, %1, %3)
    block0(%i : int, %cond : Tensor, %7 : Tensor) {
      %8 : Tensor = onnx::Mul(%2, %7)
      -> (%1, %8)
    }
  return (%2, %4);
}

%1 is a tensor of type TensorProto.INT64 (see the printed graph and the script's final check). However, %1 is also the second input to block0, which the ONNX spec requires to be tensor(bool).
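
As a workaround sketch on my side (not an official fix), the exported file can be patched after the fact so that the constant feeding cond becomes a scalar bool. This assumes graph.node[0] is the offending Constant, as in the dump above; the patched file name is made up:

import onnx
from onnx import TensorProto, helper

model = onnx.load("simple_loop.onnx")
const_node = model.graph.node[0]  # the %1 Constant from the dump above
if const_node.attribute[0].t.data_type == TensorProto.INT64:
    # Replace the int64 scalar with a scalar bool constant (value True).
    bool_const = helper.make_tensor(
        name=const_node.attribute[0].t.name,
        data_type=TensorProto.BOOL,
        dims=[],
        vals=[True],
    )
    const_node.attribute[0].t.CopyFrom(bool_const)

onnx.save(model, "simple_loop_patched.onnx")
print("patched cond dtype is BOOL?",
      model.graph.node[0].attribute[0].t.data_type == TensorProto.BOOL)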

Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.14.1
GCC version: Could not collect
CMake version: version 3.13.4

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.1
[pip3] torch==1.0.1.post2
[conda] Could not collect

Labels

module: onnx (Related to torch.onnx), oncall: jit (Add this issue/PR to JIT oncall triage queue)
