🐛 Describe the bug
When exporting a model to ONNX, I get an error when the tensor.index_fill_ op is used with dim=0 and the input tensor x is 1-D.
import torch

class IndexFillModel(torch.nn.Module):
    def __init__(self, dim_value):
        super().__init__()
        self.dim = torch.tensor(dim_value)

    def forward(self, x, index):
        print(f"in index_fill_, x={x}, dim={self.dim}, index={index}")
        return x.index_fill_(self.dim, index, -1)

model = IndexFillModel(0)
model.eval()
index = torch.tensor([1])
x = torch.tensor([4, 5, 6], dtype=torch.float)
print(f"x={x.shape}, index={index.shape}")
output = model(x.clone(), index)
print(f"model({x}, {index}) ={output}")
onnx_path = "model.onnx"
torch.onnx.export(model, (x.clone(), index), f=onnx_path)
Error log:
x=torch.Size([3]), index=torch.Size([1])
in index_fill_, x=tensor([4., 5., 6.]), dim=0, index=tensor([1])
model(tensor([4., 5., 6.]), tensor([1])) =tensor([ 4., -1., 6.])
in index_fill_, x=tensor([4., 5., 6.]), dim=0, index=tensor([1])
Torch IR graph at exception: graph(%inp.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
%inp : Long(1, strides=[1], requires_grad=0, device=cpu)):
%155 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::resolve_conj(%inp.1), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
%156 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::resolve_neg(%155), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
%175 : Long(1, strides=[1], requires_grad=0, device=cpu) = aten::resolve_conj(%inp), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
%176 : Long(1, strides=[1], requires_grad=0, device=cpu) = aten::resolve_neg(%175), scope: __main__.IndexFillModel:: # /home/sgui/pytorch/torch/_tensor_str.py:261:0
%182 : Long(device=cpu) = prim::Constant[value={0}]()
%183 : Long(device=cpu) = prim::Constant[value={-1}](), scope: __main__.IndexFillModel::
%180 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::index_fill_(%156, %182, %176, %183), scope: __main__.IndexFillModel:: # /home/sgui/test_index.py:10:0
return (%180)
Traceback (most recent call last):
File "/home/sgui/test_index.py", line 20, in <module>
torch.onnx.export(model, (x.clone(), index), f=onnx_path)
File "/home/sgui/pytorch/torch/onnx/__init__.py", line 370, in export
export(
File "/home/sgui/pytorch/torch/onnx/utils.py", line 495, in export
_export(
File "/home/sgui/pytorch/torch/onnx/utils.py", line 1418, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/home/sgui/pytorch/torch/onnx/utils.py", line 1052, in _model_to_graph
graph = _optimize_graph(
File "/home/sgui/pytorch/torch/onnx/utils.py", line 632, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/home/sgui/pytorch/torch/onnx/utils.py", line 1687, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
File "/home/sgui/pytorch/torch/onnx/symbolic_opset11.py", line 948, in index_fill
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
File "/home/sgui/pytorch/torch/onnx/symbolic_helper.py", line 1336, in _index_fill_reshape_helper
unsqueezed_index = _unsqueeze_helper(
File "/home/sgui/pytorch/torch/onnx/symbolic_helper.py", line 809, in _unsqueeze_helper
if _is_constant(axes_i[0]):
IndexError: list index out of range
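From the traceback, my reading is that the index_fill symbolic helper unsqueezes the index over every axis of the input except dim; with a 1-D input and dim=0 that axis list comes out empty, and _unsqueeze_helper then reads axes_i[0] from the empty list. A standalone sketch of that reading (a hypothetical reconstruction for illustration, not the actual helper code):

# Hypothetical reconstruction of the failing step (illustration only)
self_rank = 1   # x is 1-D
dim_value = 0   # dim passed to index_fill_

# Unsqueeze the index over every input axis except dim_value:
axes = [i for i in range(self_rank) if i != dim_value]
print(axes)  # [] -- nothing left to unsqueeze over for a 1-D input
# Reading axes[0] from this empty list would raise
# "IndexError: list index out of range", which matches the traceback above.

In the meantime, the same fill can be expressed without index_fill_ for this 1-D, dim=0 case. A minimal workaround sketch using index_put_ (assuming it takes a different symbolic path during export; I have not verified the exported graph end to end):

import torch

class IndexFillWorkaround(torch.nn.Module):
    def forward(self, x, index):
        # Fill the selected positions with -1 via index_put_ instead of
        # index_fill_ (only covers the 1-D, dim=0 case from the repro).
        fill = torch.full_like(index, -1, dtype=x.dtype)
        return x.index_put_((index,), fill)

model_wa = IndexFillWorkaround().eval()
x = torch.tensor([4.0, 5.0, 6.0])
index = torch.tensor([1])
print(model_wa(x.clone(), index))  # tensor([ 4., -1.,  6.])
torch.onnx.export(model_wa, (x.clone(), index), f="model_workaround.onnx")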
Versions
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.0.100
[pip3] numpy==1.25.0
[pip3] nvidia-cublas-cu11==11.10.3.66
[pip3] nvidia-cuda-cupti-cu11==11.7.101
[pip3] nvidia-cuda-nvrtc-cu11==11.7.99
[pip3] nvidia-cuda-runtime-cu11==11.7.99
[pip3] nvidia-cudnn-cu11==8.5.0.96
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.2.10.91
[pip3] nvidia-cusolver-cu11==11.4.0.1
[pip3] nvidia-cusparse-cu11==11.7.4.91
[pip3] nvidia-nccl-cu11==2.14.3
[pip3] nvidia-nvtx-cu11==11.7.91
[pip3] onnx==1.13.1
[pip3] onnxruntime==1.13.1
[pip3] optree==0.13.0
[pip3] pytorch-lightning==2.2.3
[pip3] torch==2.6.0a0+git86db2cd
[pip3] torchmetrics==1.3.2
[pip3] torchvision==0.15.2
[pip3] triton==2.0.0
[conda] intel-extension-for-pytorch 2.0.100 pypi_0 pypi
[conda] mkl-include 2025.0.0 pypi_0 pypi
[conda] mkl-static 2025.0.0 pypi_0 pypi
[conda] numpy 1.25.0 pypi_0 pypi
[conda] nvidia-cublas-cu11 11.10.3.66 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu11 11.7.101 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu11 11.7.99 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu11 11.7.99 pypi_0 pypi
[conda] nvidia-cudnn-cu11 8.5.0.96 pypi_0 pypi
[conda] nvidia-cufft-cu11 10.9.0.58 pypi_0 pypi
[conda] nvidia-curand-cu11 10.2.10.91 pypi_0 pypi
[conda] nvidia-cusolver-cu11 11.4.0.1 pypi_0 pypi
[conda] nvidia-cusparse-cu11 11.7.4.91 pypi_0 pypi
[conda] nvidia-nccl-cu11 2.14.3 pypi_0 pypi
[conda] nvidia-nvtx-cu11 11.7.91 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-lightning 2.2.3 pypi_0 pypi
[conda] torch 2.6.0a0+git86db2cd dev_0
[conda] torchmetrics 1.3.2 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi