
[ONNX export] Error exporting model to ONNX when tensor.index_fill op is called with dim=0 #139594

@a3213105


🐛 Describe the bug

When exporting a model to ONNX, I get an error when the tensor.index_fill_ op is called with dim=0 and the input tensor x is 1-D.
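For reference, index_fill_(dim, index, value) writes value into x at the positions listed in index along dim. In the 1-D, dim=0 case the expected result is easy to state. The pure-Python sketch below uses plain lists in place of tensors; index_fill_1d is a hypothetical helper, not a PyTorch API. It reproduces the eager result shown in the log, tensor([ 4., -1.,  6.]).

```python
from typing import List, Sequence


def index_fill_1d(x: Sequence[float], index: Sequence[int], value: float) -> List[float]:
    """Hypothetical stand-in for x.index_fill(0, index, value) on a 1-D tensor."""
    out = list(x)  # work on a copy so the input list is not modified
    n = len(out)
    for i in index:
        # This sketch only accepts non-negative, in-range indices.
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for size {n}")
        out[i] = value  # overwrite the selected position with the fill value
    return out


print(index_fill_1d([4.0, 5.0, 6.0], [1], -1.0))  # [4.0, -1.0, 6.0]
```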

import torch

class IndexFillModel(torch.nn.Module):
    def __init__(self, dim_value):
        super().__init__()
        # dim is kept as a 0-d tensor; it shows up as a constant in the traced graph
        self.dim = torch.tensor(dim_value)

    def forward(self, x, index):
        print(f"in index_fill_, x={x}, dim={self.dim}, index={index}")
        # fill x at the positions in index (along self.dim) with -1, in place
        return x.index_fill_(self.dim, index, -1)

model = IndexFillModel(0)
model.eval()
index = torch.tensor([1])
x = torch.tensor([4, 5, 6], dtype=torch.float)  # 1-D input
print(f"x={x.shape}, index={index.shape}")
output = model(x.clone(), index)  # eager mode works
print(f"model({x}, {index}) ={output}")
onnx_path = "model.onnx"
torch.onnx.export(model, (x.clone(), index), f=onnx_path)  # raises IndexError
Output and error log:
x=torch.Size([3]), index=torch.Size([1])
in index_fill_, x=tensor([4., 5., 6.]), dim=0, index=tensor([1])
model(tensor([4., 5., 6.]), tensor([1])) =tensor([ 4., -1.,  6.])
in index_fill_, x=tensor([4., 5., 6.]), dim=0, index=tensor([1])
Torch IR graph at exception: graph(%inp.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
      %inp : Long(1, strides=[1], requires_grad=0, device=cpu)):
  %155 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::resolve_conj(%inp.1), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
  %156 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::resolve_neg(%155), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
  %175 : Long(1, strides=[1], requires_grad=0, device=cpu) = aten::resolve_conj(%inp), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
  %176 : Long(1, strides=[1], requires_grad=0, device=cpu) = aten::resolve_neg(%175), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
  %182 : Long(device=cpu) = prim::Constant[value={0}]()
  %183 : Long(device=cpu) = prim::Constant[value={-1}](), scope: __main__.IndexFillModel::
  %180 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::index_fill_(%156, %182, %176, %183), scope: __main__.IndexFillModel:: # /home/sgui/test_index.py:10:0
  return (%180)

Traceback (most recent call last):
  File "/home/sgui/test_index.py", line 20, in <module>
    torch.onnx.export(model, (x.clone(), index), f=onnx_path)
  File "/home/sgui/pytorch/torch/onnx/__init__.py", line 370, in export
    export(
  File "/home/sgui/pytorch/torch/onnx/utils.py", line 495, in export
    _export(
  File "/home/sgui/pytorch/torch/onnx/utils.py", line 1418, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/sgui/pytorch/torch/onnx/utils.py", line 1052, in _model_to_graph
    graph = _optimize_graph(
  File "/home/sgui/pytorch/torch/onnx/utils.py", line 632, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/sgui/pytorch/torch/onnx/utils.py", line 1687, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/sgui/pytorch/torch/onnx/symbolic_opset11.py", line 948, in index_fill
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
  File "/home/sgui/pytorch/torch/onnx/symbolic_helper.py", line 1336, in _index_fill_reshape_helper
    unsqueezed_index = _unsqueeze_helper(
  File "/home/sgui/pytorch/torch/onnx/symbolic_helper.py", line 809, in _unsqueeze_helper
    if _is_constant(axes_i[0]):
IndexError: list index out of range
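The traceback ends in _unsqueeze_helper reading axes_i[0] on an empty list. A plausible reading (an assumption, not checked against the PyTorch source) is that _index_fill_reshape_helper unsqueezes index on every axis of self except dim, so it can later be expanded to the shape of self; for a 1-D self with dim=0 that list of axes is empty. The pure-Python sketch below models this with shapes as lists. unsqueeze_axes, unsqueeze_unguarded and unsqueeze_guarded are hypothetical helpers, and the guarded variant only illustrates one possible fix: skip the unsqueeze when there are no axes to add.

```python
from typing import List


def apply_unsqueeze(shape: List[int], axes: List[int]) -> List[int]:
    # Insert a size-1 dimension at each axis. As in ONNX Unsqueeze, axes are
    # positions in the output shape, so insert them in ascending order.
    out = list(shape)
    for a in sorted(axes):
        out.insert(a, 1)
    return out


def unsqueeze_axes(self_rank: int, dim: int) -> List[int]:
    # ASSUMPTION: index is unsqueezed on every axis of self except dim,
    # giving it shape [1, ..., len(index), ..., 1].
    return [i for i in range(self_rank) if i != dim]


def unsqueeze_unguarded(index_shape: List[int], axes: List[int]) -> List[int]:
    # Mimics the failing pattern in the traceback: axes[0] is inspected
    # before checking that there is anything to unsqueeze.
    _first = axes[0]  # IndexError: list index out of range when axes == []
    return apply_unsqueeze(index_shape, axes)


def unsqueeze_guarded(index_shape: List[int], axes: List[int]) -> List[int]:
    # Hypothetical fix: with no axes to add (1-D self, dim=0) the index
    # already has the right rank, so return its shape unchanged.
    if not axes:
        return list(index_shape)
    return apply_unsqueeze(index_shape, axes)


axes = unsqueeze_axes(self_rank=1, dim=0)  # the case from this report
print(axes)  # []
try:
    unsqueeze_unguarded([1], axes)
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: list index out of range
print(unsqueeze_guarded([1], axes))  # [1]
print(unsqueeze_guarded([2], unsqueeze_axes(self_rank=3, dim=1)))  # [1, 2, 1]
```

For self of rank 2 or higher the axes list is never empty, which would be consistent with the error appearing only for 1-D inputs.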

Versions

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.0.100
[pip3] numpy==1.25.0
[pip3] nvidia-cublas-cu11==11.10.3.66
[pip3] nvidia-cuda-cupti-cu11==11.7.101
[pip3] nvidia-cuda-nvrtc-cu11==11.7.99
[pip3] nvidia-cuda-runtime-cu11==11.7.99
[pip3] nvidia-cudnn-cu11==8.5.0.96
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.2.10.91
[pip3] nvidia-cusolver-cu11==11.4.0.1
[pip3] nvidia-cusparse-cu11==11.7.4.91
[pip3] nvidia-nccl-cu11==2.14.3
[pip3] nvidia-nvtx-cu11==11.7.91
[pip3] onnx==1.13.1
[pip3] onnxruntime==1.13.1
[pip3] optree==0.13.0
[pip3] pytorch-lightning==2.2.3
[pip3] torch==2.6.0a0+git86db2cd
[pip3] torchmetrics==1.3.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] intel-extension-for-pytorch 2.0.100 pypi_0 pypi
[conda] mkl-include 2025.0.0 pypi_0 pypi
[conda] mkl-static 2025.0.0 pypi_0 pypi
[conda] numpy 1.25.0 pypi_0 pypi
[conda] nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu11 11.7.101 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
[conda] nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
[conda] nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
[conda] nvidia-curand-cu11 10.2.10.91 pypi_0 pypi
[conda] nvidia-cusolver-cu11 11.4.0.1 pypi_0 pypi
[conda] nvidia-cusparse-cu11 11.7.4.91 pypi_0 pypi
[conda] nvidia-nccl-cu11 2.14.3 pypi_0 pypi
[conda] nvidia-nvtx-cu11 11.7.91 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-lightning 2.2.3 pypi_0 pypi
[conda] torch 2.6.0a0+git86db2cd dev_0
[conda] torchmetrics 1.3.2 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi

Labels: module: onnx, triaged
