torch.onnx._export does not support tensor sum with multiple dims #22066

@yil8

Description

🐛 Bug

To Reproduce

import torch
from torch import nn

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()

    def forward(self, x):
        return x.sum(dim=(2, 3), keepdim=True)

model = Test()

x = torch.zeros((16, 3, 256, 256))

torch.onnx._export(model, x, "test.onnx", verbose=True)

Error message

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-1-4158e5bbe3ab> in <module>
     13 x = torch.zeros((16, 3, 256, 256))
     14 
---> 15 torch.onnx._export(model, x, "test.onnx", verbose=True)

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/__init__.py in _export(*args, **kwargs)
     20 def _export(*args, **kwargs):
     21     from torch.onnx import utils
---> 22     return utils._export(*args, **kwargs)
     23 
     24 

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate)
    279                                                training, input_names,
    280                                                output_names, operator_export_type,
--> 281                                                example_outputs, propagate)
    282 
    283     # TODO: Don't allocate a in-memory string for the protobuf

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, f, verbose, training, input_names, output_names, operator_export_type, example_outputs, propagate)
    225         params = list(_unique_state_dict(model).values())
    226 
--> 227     graph = _optimize_graph(graph, operator_export_type)
    228 
    229     # NB: ONNX requires complete information about output types, which might be

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/utils.py in _optimize_graph(graph, operator_export_type)
    153 
    154     if operator_export_type != OperatorExportTypes.RAW:
--> 155         graph = torch._C._jit_pass_onnx(graph, operator_export_type)
    156         torch._C._jit_pass_lint(graph)
    157         torch._C._jit_pass_onnx_peephole(graph)

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/__init__.py in _run_symbolic_function(*args, **kwargs)
     50 def _run_symbolic_function(*args, **kwargs):
     51     from torch.onnx import utils
---> 52     return utils._run_symbolic_function(*args, **kwargs)
     53 
     54 

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/utils.py in _run_symbolic_function(g, n, inputs, env, operator_export_type)
    502                     return None
    503                 fn = getattr(torch.onnx.symbolic, op_name)
--> 504                 return fn(g, *inputs, **attrs)
    505 
    506         elif ns == "prim":

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/symbolic.py in symbolic(g, self, dim, keepdim)
    329         else:
    330             # dim-reduce path
--> 331             dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
    332             return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
    333     return symbolic

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/symbolic.py in _get_const(value, desc, arg_name)
     73     if _is_value(value) and value.node().kind() != 'onnx::Constant':
     74         raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
---> 75     return _parse_arg(value, desc)
     76 
     77 

~/env/py3.6/lib/python3.6/site-packages/torch/onnx/symbolic.py in _parse_arg(value, desc)
     46     tval = value.node()['value']
     47     if desc == 'i':
---> 48         return int(tval)
     49     elif desc == 'f':
     50         return float(tval)

ValueError: only one element tensors can be converted to Python scalars
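From the traceback, the `sum` symbolic in torch/onnx/symbolic.py parses `dim` via `_get_const(dim, 'i', 'dim')`, i.e. as a single integer, so a tuple like `(2, 3)` reaches `int(tval)` on a multi-element tensor and fails. As a possible workaround until multi-dim reduction is supported (a minimal sketch only, relying on the single-dim reduce path shown in the traceback), reducing one dimension at a time with `keepdim=True` produces the same result:

import torch
from torch import nn

class Test(nn.Module):
    def forward(self, x):
        # Reduce one dim at a time so the ONNX symbolic only ever sees a
        # single integer `dim`; equivalent to x.sum(dim=(2, 3), keepdim=True).
        return x.sum(dim=3, keepdim=True).sum(dim=2, keepdim=True)

model = Test()
x = torch.zeros((16, 3, 256, 256))
torch.onnx._export(model, x, "test.onnx", verbose=True)

With keepdim=True, chaining the two single-dim sums is numerically equivalent to the multi-dim sum, and each reduction stays on the dim-reduce path the symbolic already handles.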

Environment

  • PyTorch Version: 1.0.1.post2
  • OS: Ubuntu 18.04.2 LTS
  • How you installed PyTorch: pip install torch
  • Python version: 3.6.7
  • CUDA/cuDNN version: 10.1/7.5
  • GPU models and configuration: NVIDIA 1080Ti

Labels

module: onnx, triaged
