Skip to content

Commit 599f3da

Browse files
BowenBaowschin
authored andcommitted
Add helper function update_inputs_outputs_dims to tools (#2148)
* Add a helper function update_inputs_outputs_dims to tools * fix link to doc * newline at the end * add test for tools * doc props * nit * ci tests * ci tests 2 * accept shapes by dictionary inputs and add more error handling * Update onnx/tools/update_model_dims.py nit: rephrasing Co-Authored-By: Wei-Sheng Chin <[email protected]> * remove debug line * fix type annotation * fix annotation * fix annotation * fix annotation * fix flake8
1 parent 3e6382b commit 599f3da

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed

docs/PythonAPIOverview.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,19 @@ import onnx.utils
214214
model = onnx.load('path/to/the/model.onnx')
215215
polished_model = onnx.utils.polish_model(model)
216216
```
217+
218+
## Tools
219+
### Updating Model's Inputs Outputs Dimension Sizes with Variable Length
220+
Function `update_inputs_outputs_dims` updates the dimension of the inputs and outputs of the model,
221+
to the provided values in the parameter. You could provide both static and dynamic dimension size,
222+
by using dim_param. For more information on static and dynamic dimension size, checkout [Tensor Shapes](IR.md#tensor-shapes).
223+
224+
The function runs model checker after the input/output sizes are updated.
225+
```python
226+
import onnx
227+
from onnx.tools import update_model_dims
228+
229+
model = onnx.load('path/to/the/model.onnx')
230+
# Here both 'seq', 'batch' and -1 are dynamic using dim_param.
231+
variable_length_model = update_model_dims.update_inputs_outputs_dims(model, {'input_name': ['seq', 'batch', 3, -1]}, {'output_name': ['seq', 'batch', 1, -1]})
232+
```

onnx/test/tools_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import unittest
7+
import onnx
8+
from onnx.tools import update_model_dims
9+
from onnx import helper, TensorProto
10+
11+
12+
class TestToolsFunctions(unittest.TestCase):
13+
def test_update_inputs_outputs_dim(self): # type: () -> None
14+
node_def = helper.make_node(
15+
"Conv",
16+
inputs=['x', 'W'],
17+
outputs=['y'],
18+
kernel_shape=[3, 3],
19+
strides=[2, 2],
20+
)
21+
graph_def = helper.make_graph(
22+
[node_def],
23+
'test',
24+
[helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]),
25+
helper.make_tensor_value_info('W', TensorProto.FLOAT, [1, 1, 3, 3])],
26+
[helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 2, 2])]
27+
)
28+
model_def = helper.make_model(graph_def, producer_name='test')
29+
updated_def = update_model_dims.update_inputs_outputs_dims(
30+
model_def,
31+
{
32+
"x": [1, 1, 'x1', -1],
33+
"W": [1, 1, 3, 3],
34+
},
35+
{
36+
"y": [1, 1, -1, -1],
37+
})
38+
onnx.checker.check_model(updated_def)
39+
self.assertEqual(updated_def.graph.input[0].type.tensor_type.shape.dim[2].dim_param, 'x1')
40+
self.assertEqual(updated_def.graph.input[0].type.tensor_type.shape.dim[3].dim_param, 'x_3')
41+
self.assertEqual(updated_def.graph.output[0].type.tensor_type.shape.dim[2].dim_param, 'y_2')
42+
self.assertEqual(updated_def.graph.output[0].type.tensor_type.shape.dim[3].dim_param, 'y_3')
43+
44+
45+
if __name__ == '__main__':
46+
unittest.main()

onnx/tools/update_model_dims.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
from six import string_types
7+
from typing import Any, List, Text, Dict, Set
8+
from onnx import ModelProto, ValueInfoProto
9+
10+
import onnx.checker
11+
12+
13+
def update_inputs_outputs_dims(model, input_dims, output_dims): # type: (ModelProto, Dict[Text, List[Any]], Dict[Text, List[Any]]) -> ModelProto
14+
"""
15+
This function updates the dimension sizes of the model's inputs and outputs to the values
16+
provided in input_dims and output_dims. if the dim value provided is negative, a unique dim_param
17+
will be set for that dimension.
18+
19+
Example. if we have the following shape for inputs and outputs:
20+
shape(input_1) = ('b', 3, 'w', 'h')
21+
shape(input_2) = ('b', 4)
22+
and shape(output) = ('b', 'd', 5)
23+
24+
The parameters can be provided as:
25+
input_dims = {
26+
"input_1": ['b', 3, 'w', 'h'],
27+
"input_2": ['b', 4],
28+
}
29+
output_dims = {
30+
"output": ['b', -1, 5]
31+
}
32+
33+
Putting it together:
34+
model = onnx.load('model.onnx')
35+
updated_model = update_inputs_outputs_dims(model, input_dims, output_dims)
36+
onnx.save(updated_model, 'model.onnx')
37+
"""
38+
dim_param_set = set() # type: Set[Text]
39+
40+
def init_dim_param_set(dim_param_set, value_infos): # type: (Set[Text], List[ValueInfoProto]) -> None
41+
for info in value_infos:
42+
shape = info.type.tensor_type.shape
43+
for dim in shape.dim:
44+
if dim.HasField('dim_param'):
45+
dim_param_set.add(dim.dim_param) # type: ignore
46+
47+
init_dim_param_set(dim_param_set, model.graph.input) # type: ignore
48+
init_dim_param_set(dim_param_set, model.graph.output) # type: ignore
49+
init_dim_param_set(dim_param_set, model.graph.value_info) # type: ignore
50+
51+
def update_dim(tensor, dim, j, name): # type: (ValueInfoProto, Any, int, Text) -> None
52+
dim_proto = tensor.type.tensor_type.shape.dim[j]
53+
if isinstance(dim, int):
54+
if dim >= 0:
55+
if dim_proto.HasField('dim_value') and dim_proto.dim_value != dim:
56+
raise ValueError('Unable to set dimension value to {} for axis {} of {}. Contradicts existing dimension value {}.'
57+
.format(dim, j, name, dim_proto.dim_value))
58+
dim_proto.dim_value = dim
59+
else:
60+
generated_dim_param = name + '_' + str(j)
61+
if generated_dim_param in dim_param_set:
62+
raise ValueError('Unable to generate unique dim_param for axis {} of {}. Please manually provide a dim_param value.'
63+
.format(j, name))
64+
dim_proto.dim_param = generated_dim_param
65+
elif isinstance(dim, string_types):
66+
dim_proto.dim_param = dim
67+
else:
68+
raise ValueError('Only int or str is accepted as dimension value, incorrect type: {}'.format(type(dim)))
69+
70+
for input in model.graph.input:
71+
input_name = input.name
72+
input_dim_arr = input_dims[input_name]
73+
for j, dim in enumerate(input_dim_arr):
74+
update_dim(input, dim, j, input_name)
75+
76+
for output in model.graph.output:
77+
output_name = output.name
78+
output_dim_arr = output_dims[output_name]
79+
for j, dim in enumerate(output_dim_arr):
80+
update_dim(output, dim, j, output_name)
81+
82+
onnx.checker.check_model(model)
83+
return model

0 commit comments

Comments
 (0)