[ONNX] The shape of PReLU weight is wrong #21271

@daquexian

🐛 Bug

The shape of the PReLU weight in the exported model is incompatible with the ONNX documentation.

According to the ONNX documentation, the weight should have shape [channels, 1, 1], but the exported weight has shape [channels].

This makes the exported ONNX model invalid. Related issue: daquexian/onnx-simplifier#7

To Reproduce

Code sample:

import torch
from torch import nn
import numpy as np
import onnx
import onnxruntime as rt


shape = (2, 16, 96, 96)

def generate_model():
    net = nn.PReLU(16)
    model_name = 'only_relu.onnx'
    dummy_input = torch.randn(*shape)
    torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
    model = onnx.load(model_name)
    return model


def forward(model, inputs):
    sess = rt.InferenceSession(model.SerializeToString())
    outputs = [x.name for x in sess.get_outputs()]
    res = dict(zip(outputs, sess.run(outputs, inputs)))
    return res


forward(generate_model(), {'input': np.random.rand(*shape).astype(np.float32)})

Error:

RuntimeError: Method run failed due to: [ONNXRuntimeError] : 1 : GENERAL ERROR : /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:341 void onnxruntime::BroadcastIterator::Init(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 16 by 96
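
The "16 by 96" in the error above is what right-aligned (NumPy-style) broadcasting would report for these shapes. The following is a minimal sketch of that rule in plain Python, not the onnxruntime implementation; the helper name `broadcast_conflict` is hypothetical, and it assumes the slope has no more dimensions than the input.

```python
def broadcast_conflict(input_shape, slope_shape):
    """Return the first (slope_dim, input_dim) pair that breaks right-aligned
    broadcasting, or None if slope_shape broadcasts to input_shape.

    Hypothetical helper for illustration only; assumes
    len(slope_shape) <= len(input_shape).
    """
    if len(slope_shape) > len(input_shape):
        raise ValueError("slope has more dimensions than the input")
    # Compare dimensions starting from the trailing (rightmost) axis.
    for slope_dim, input_dim in zip(reversed(slope_shape), reversed(input_shape)):
        if slope_dim != 1 and slope_dim != input_dim:
            return (slope_dim, input_dim)
    return None


input_shape = (2, 16, 96, 96)

# A [channels] slope is aligned with the last axis (width), not with channels.
print(broadcast_conflict(input_shape, (16,)))       # (16, 96)

# A [channels, 1, 1] slope lines up with the channel axis.
print(broadcast_conflict(input_shape, (16, 1, 1)))  # None
```

With a slope of shape [16], the 16 is compared against the trailing dimension 96, which is the mismatch the runtime reports. With shape [16, 1, 1] every aligned pair is either equal or 1.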

Expected behavior

The model runs correctly.

Environment

PyTorch version: 1.1.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Manjaro Linux
GCC version: (GCC) 8.3.0
CMake version: version 3.14.4

Python version: 3.7
Is CUDA available: No

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] torch==1.1.0
[pip3] torchvision==0.2.1
[conda] Could not collect

Labels

module: onnx (Related to torch.onnx)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
