Description
Describe the bug
Say a=[0, 1, 2, ..., N-1] is passed to a 1-D Pad op with mode='reflect' and padding=[-(N-1), N-1]. ORT then either segfaults or returns values that look like garbage memory, depending on N.
I am guessing that ORT performs the negative padding (i.e., slicing) before the reflected padding. So it will remove the first N-1 elements from a, at which point there are not enough elements to be "reflected".
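To make the ordering issue concrete, here is a minimal pure-Python sketch (not ORT code). `reflect_pad_then_crop` is a hypothetical helper that applies the positive reflect padding to the original data first and only then crops for the negative padding. This is one possible reading of the intended semantics, assumed here for illustration, not something the ONNX spec is confirmed to mandate.

```python
def reflect_pad_then_crop(a, pre, post):
    """Reflect-pad the ORIGINAL 1-D list first, then crop for negative pads.

    Hypothetical reference helper, not part of ONNX Runtime. Assumes each
    positive pad is at most len(a) - 1, so a single reflection is enough.
    """
    n = len(a)
    pre_pos, post_pos = max(pre, 0), max(post, 0)
    if max(pre_pos, post_pos) > n - 1:
        raise ValueError("reflect pad larger than len(a) - 1")
    # Reflection mirrors around the edge element without repeating it.
    head = [a[pre_pos - i] for i in range(pre_pos)]
    tail = [a[n - 2 - i] for i in range(post_pos)]
    padded = head + a + tail
    # Negative pads remove elements from the corresponding end.
    start = max(-pre, 0)
    stop = len(padded) - max(-post, 0)
    return padded[start:stop]


N = 4
a = list(range(N))

# Slice-first ordering: only one element survives, so there is nothing
# left to mirror for a post-pad of N - 1 elements.
sliced = a[N - 1:]
print(len(sliced), "element left; a reflect pad of", N - 1, "needs", N)

# Pad-first ordering: the original N elements are enough.
print(reflect_pad_then_crop(a, -(N - 1), N - 1))  # [3, 2, 1, 0]
```

Under this assumed pad-first ordering, the reproduction case has a well-defined result, whereas the slice-first ordering has only one element left to reflect from.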
I tried reading the code related to Pad. Although I don't fully understand it, the logic appears to confirm my guess:
SliceIterator<T> input(input_tensor, input_shape, input_starts, input_extents, {});
onnxruntime/onnxruntime/core/providers/cpu/tensor/pad.cc
Lines 421 to 427 in 0869f4f
output = input.CopyInnermostAxisSolitaryInnerStep(output);
int64_t prePad = reshaped_pad[inner_axis];
int64_t postPad = reshaped_pad[inner_axis + new_dims_count];
if (inner_no_pad_size == 1) {
  PadInnermostAxis(axisStart - prePad, axisStart + prePad, -1 /* inputDelta */, prePad);
  PadInnermostAxis(output, output - 2, -1 /* inputDelta */, postPad);
L421 copies the sliced data into a new, smaller buffer, output, and the reflected padding is then performed on that buffer at L426-427. In this case the smaller buffer does not have enough elements to "reflect", so out-of-bounds array accesses follow, causing the errors.
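The out-of-bounds access can be sketched with plain index arithmetic. The model below is an assumption based only on the quoted call `PadInnermostAxis(output, output - 2, -1, postPad)`: it supposes that `output` points one past the last copied element and that the post-pad reads backward from `output - 2`, one element per step. The function names are hypothetical and this is not ORT code.

```python
def post_pad_read_offsets(copied_len, post_pad):
    """Offsets, relative to the start of the copied data, that a backward
    reflect post-pad would read if it starts at `end - 2` and steps by -1.

    Hypothetical model of the quoted call; not ORT code.
    """
    end = copied_len  # `output` is assumed to point one past the copied data
    return [end - 2 - i for i in range(post_pad)]


def out_of_bounds(copied_len, post_pad):
    """Return the read offsets that fall before the copied data."""
    return [o for o in post_pad_read_offsets(copied_len, post_pad) if o < 0]


N = 4
# A pre-pad of -(N - 1) leaves a single copied element; the post-pad is N - 1.
print(post_pad_read_offsets(1, N - 1))  # [-1, -2, -3]
print(out_of_bounds(1, N - 1))          # [-1, -2, -3]
# Without the negative pre-pad, all N elements are copied and every read is valid.
print(out_of_bounds(N, N - 1))          # []
```

In this model every read in the reproduction case lands before the copied data, which would be consistent with either a crash or garbage values depending on what happens to sit there.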
Urgency
None
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04
- ONNX Runtime installed from (source or binary): source
- ONNX Runtime version: 0869f4f
- Python version: 3.7.11
- Visual Studio version (if applicable):
- GCC/Compiler version (if compiling from source): clang version 11.1.0-++20211011094159+1fdec59bffc1-1~exp1~20211011214614.8
- CUDA/cuDNN version: 11.2
- GPU model and memory: RTX2080, 8GB
To Reproduce
Run this script:
import onnxruntime as ort
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import numpy as np
import sys
N = int(sys.argv[1])
print('testing N=', N)
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [N])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [N])
pads = np.array([-(N - 1), (N - 1)], dtype=np.int64)
node = helper.make_node(
    'Pad',
    inputs=['x', 'pads'],
    outputs=['y'],
    mode='reflect'
)
graph_def = helper.make_graph(
    [node],  # nodes
    'test-model',  # name
    [x],  # inputs
    [y],  # outputs
    initializer=[helper.make_tensor('pads', TensorProto.INT64, [2], pads)]
)
model_def = helper.make_model(graph_def, producer_name='onnx-example')
onnx.checker.check_model(model_def, full_check=True)
onnx.save(model_def, "output.onnx")
x = np.arange(N, dtype=np.float32)
print('input=')
print(x)
# print(f'input=[0,1,...,{N-1}]')
print(f'padding=[-{N-1}, {N-1}]')
print('output=')
y_ort = ort.InferenceSession(
    "output.onnx", providers=["CPUExecutionProvider"]
).run(["y"], {"x": x})[0]
print(y_ort)
Expected behavior
It appears that ORT supports a Pad op that uses negative and positive paddings at the same time (and the ONNX spec doesn't seem to prohibit this either).
- If that is the case, this sounds like a bug. Hopefully it is an easy fix: passing the original buffer as the 2nd argument of PadInnermostAxis seems to do the job, but unfortunately I don't know the code well enough to make the patch.
- If not, it would be great to have a check that rejects mixed positive and negative padding, instead of segfaulting or silently producing incorrect values.
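For the second option, a validation step could look roughly like the sketch below. `check_reflect_pads` is a hypothetical helper, not an ONNX Runtime API. It assumes negative pads are applied first (as slicing) and that a positive reflect pad can mirror at most `remaining - 1` elements, so it is narrower than rejecting every mixed-sign padding: it only rejects the combinations that cannot be satisfied under that assumption.

```python
def check_reflect_pads(dim, pre, post):
    """Reject reflect pads that cannot be satisfied after slicing.

    Hypothetical validation helper (not an ONNX Runtime API). Assumes negative
    pads are applied first, as slicing, and that a positive reflect pad may
    mirror at most `remaining - 1` elements.
    """
    remaining = dim + min(pre, 0) + min(post, 0)
    if remaining < 0:
        raise ValueError(f"negative pads remove more than {dim} elements")
    limit = max(remaining - 1, 0)
    for name, pad in (("pre", pre), ("post", post)):
        if pad > limit:
            raise ValueError(
                f"{name} reflect pad {pad} exceeds {limit} "
                f"({remaining} element(s) remain after slicing)"
            )


N = 4
check_reflect_pads(N, 0, N - 1)  # accepted: 4 elements, reflect pad of 3
try:
    check_reflect_pads(N, -(N - 1), N - 1)  # the case from this report
except ValueError as err:
    print("rejected:", err)
```

With a check of this kind the reproduction case would fail with an explicit error instead of crashing or returning arbitrary values.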
and the stack trace for segfault run:

Additional context
