
Commit 98e3aae

pk-g authored and facebook-github-bot committed
Adding support for exporting models with variable length input/output to ONNX (#20034)
Summary: Proposal: https://gist.github.com/pk-g/cc45ff8c5891b5699bffd883a87f13ae?fbclid=IwAR17bRA7Fks4APoZRYiNa93UkLdoFCpRDuIYEx0lNVyPTyaDAShbEnytiQo

Pull Request resolved: #20034

Reviewed By: zrphercule

Differential Revision: D15606731

Pulled By: houseroad

fbshipit-source-id: 247251e07b4893cb3f7a1287948b1f57aadb7851
1 parent ba2bdf8 commit 98e3aae
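For orientation, the user-facing API this commit adds looks roughly like the following. This is a minimal sketch based on the test added below; the model, file name, and axis names are illustrative, not part of the commit:

    import torch
    import torch.nn as nn

    model = nn.Conv2d(3, 2, 3)
    x = torch.ones(5, 3, 6, 6)

    # Mark batch, height, and width of the input as dynamic; give the
    # output's dynamic batch dimension an explicit name.
    torch.onnx.export(model, x, 'conv2d.onnx',
                      input_names=['input_1'], output_names=['output_1'],
                      example_outputs=model(x),
                      dynamic_axes={'input_1': [0, 2, 3],
                                    'output_1': {0: 'batch'}})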

File tree

5 files changed: +146 -18 lines changed


test/onnx/test_operators.py

Lines changed: 23 additions & 0 deletions
@@ -248,6 +248,29 @@ def test_conv(self):
         x = torch.ones(20, 16, 50, 40, requires_grad=True)
         self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)
 
+    def test_conv_variable_length(self):
+        x = torch.ones(5, 3, 6, 6, requires_grad=True)
+        model = torch.nn.Conv2d(3, 2, 3)
+        y = model(x)
+
+        dynamic_axes = {'input_1': [0, 2, 3], 'output_1': {0: 'output_1_variable_dim_0', 1: 'output_1_variable_dim_1'}}
+        model_proto_name = 'conv2d.onnx'
+        torch.onnx.export(model, x, model_proto_name, verbose=True, input_names=["input_1"], output_names=["output_1"],
+                          example_outputs=y, dynamic_axes=dynamic_axes)
+
+        import onnx
+        onnx_model = onnx.load(model_proto_name)
+        onnx.checker.check_model(onnx_model)
+
+        # Assert that default dynamic axis names are generated when custom names are not provided
+        assert(onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param == "input_1_dynamic_axes_1")
+        assert(onnx_model.graph.input[0].type.tensor_type.shape.dim[2].dim_param == "input_1_dynamic_axes_2")
+        assert(onnx_model.graph.input[0].type.tensor_type.shape.dim[3].dim_param == "input_1_dynamic_axes_3")
+
+        # Assert that custom names are applied when provided
+        assert(onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_param == "output_1_variable_dim_0")
+        assert(onnx_model.graph.output[0].type.tensor_type.shape.dim[1].dim_param == "output_1_variable_dim_1")
+
     def test_convtranspose(self):
         x = torch.ones(2, 3, 4, 5, requires_grad=True)
         self.assertONNX(nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False,
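Once exported with dynamic axes, the model accepts inputs whose marked dimensions differ from the traced sizes. A quick sanity check with onnxruntime — assuming it is installed; this snippet is not part of the commit:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('conv2d.onnx')
    # The trace used shape (5, 3, 6, 6); axes 0, 2, and 3 were marked dynamic.
    for shape in [(1, 3, 8, 8), (7, 3, 12, 10)]:
        x = np.random.randn(*shape).astype(np.float32)
        (y,) = sess.run(None, {'input_1': x})
        print(shape, '->', y.shape)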

torch/csrc/jit/export.cpp

Lines changed: 33 additions & 12 deletions
@@ -129,13 +129,17 @@ class EncoderBase {
       onnx::GraphProto* graph_proto,
       const std::shared_ptr<Graph>& graph,
       const std::map<std::string, at::Tensor>& initializers =
-          std::map<std::string, at::Tensor>());
+          std::map<std::string, at::Tensor>(),
+      const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
+          std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>());
 
   void EncodeBlock(
       onnx::GraphProto* graph_proto,
       const Block* block,
       const std::map<std::string, at::Tensor>& initializers =
-          std::map<std::string, at::Tensor>());
+          std::map<std::string, at::Tensor>(),
+      const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
+          std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>());
 
   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -149,7 +153,9 @@ class EncoderBase {
   virtual void EncodeValueInfo(
       onnx::GraphProto* graph_proto,
       onnx::ValueInfoProto* v,
-      const Value* n);
+      const Value* n,
+      const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes =
+          std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>());
 
   void AddAttribute(
       onnx::NodeProto* node_proto,
@@ -206,16 +212,24 @@ EncoderBase::EncoderBase(
 void EncoderBase::EncodeValueInfo(
     onnx::GraphProto* graph_proto,
     onnx::ValueInfoProto* v,
-    const Value* n) {
-  v->set_name(n->uniqueName());
+    const Value* n,
+    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
+  std::string name = n->uniqueName();
+  v->set_name(name);
   if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
     onnx::TypeProto* t = v->mutable_type();
     onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
     onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
     const std::vector<std::int64_t>& sizes = node_type->sizes();
     for (size_t i = 0; i < sizes.size(); i++) {
       shape->add_dim();
-      shape->mutable_dim(i)->set_dim_value(sizes[i]);
+      if ((dynamic_axes.find(name) != dynamic_axes.end()) &&
+          (dynamic_axes.at(name).find(i) != dynamic_axes.at(name).end())) {
+        shape->mutable_dim(i)->set_dim_param(dynamic_axes.at(name).at(i));
+      }
+      else {
+        shape->mutable_dim(i)->set_dim_value(sizes[i]);
+      }
     }
     tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
   } else if (BoolTypePtr node_type = n->type()->cast<BoolType>()) {
@@ -228,14 +242,16 @@ void EncoderBase::EncodeValueInfo(
 void EncoderBase::EncodeGraph(
     onnx::GraphProto* graph_proto,
     const std::shared_ptr<Graph>& graph,
-    const std::map<std::string, at::Tensor>& initializers) {
-  EncodeBlock(graph_proto, graph->block(), initializers);
+    const std::map<std::string, at::Tensor>& initializers,
+    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
+  EncodeBlock(graph_proto, graph->block(), initializers, dynamic_axes);
 }
 
 void EncoderBase::EncodeBlock(
     onnx::GraphProto* graph_proto,
     const Block* block,
-    const std::map<std::string, at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers,
+    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -246,11 +262,11 @@ void EncoderBase::EncodeBlock(
 
   for (auto input : block->inputs()) {
     onnx::ValueInfoProto* v = graph_proto->add_input();
-    EncodeValueInfo(graph_proto, v, input);
+    EncodeValueInfo(graph_proto, v, input, dynamic_axes);
   }
   for (auto output : block->outputs()) {
     onnx::ValueInfoProto* v = graph_proto->add_output();
-    EncodeValueInfo(graph_proto, v, output);
+    EncodeValueInfo(graph_proto, v, output, dynamic_axes);
   }
   for (auto node : block->nodes()) {
     bool is_raw_export =
@@ -404,6 +420,7 @@ class GraphEncoder : public EncoderBase {
       int64_t onnx_opset_version,
       onnx_torch::OperatorExportTypes operator_export_type,
       const std::map<std::string, at::Tensor>& initializers,
+      const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
      bool defer_weight_export,
      bool strip_doc);
 
@@ -426,6 +443,7 @@ GraphEncoder::GraphEncoder(
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
     const std::map<std::string, at::Tensor>& initializers,
+    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -438,7 +456,7 @@ GraphEncoder::GraphEncoder(
   // This is the version of ONNX operator set we are targeting
   imp->set_version(onnx_opset_version);
 
-  EncodeGraph(model_proto_.mutable_graph(), graph, initializers);
+  EncodeGraph(model_proto_.mutable_graph(), graph, initializers, dynamic_axes);
 
   for (const std::string& domain : domains_) {
     auto* opset = model_proto_.add_opset_import();
@@ -1112,6 +1130,7 @@ std::string pretty_print_onnx(
       onnx_opset_version,
       operator_export_type,
       initializers,
+      std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>{},
       defer_weight_export,
       true);
   if (google_printer) {
@@ -1129,6 +1148,7 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
+    const std::unordered_map<std::string, std::unordered_map<std::int64_t, std::string>>& dynamic_axes,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type,
     bool strip_doc_string) {
@@ -1137,6 +1157,7 @@ std::tuple<std::string, RawDataExportMap> export_onnx(
       onnx_opset_version,
       operator_export_type,
       initializers,
+      dynamic_axes,
      defer_weight_export,
      strip_doc_string);
  return std::make_tuple(
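The heart of the change is in EncodeValueInfo: a dimension listed in dynamic_axes is serialized as a symbolic dim_param (a name) rather than a fixed dim_value. A rough Python equivalent using the onnx helpers — illustrative only, not part of the commit; the tensor name and sizes are assumptions:

    import onnx
    from onnx import TensorProto, helper

    sizes = [5, 3, 6, 6]  # concrete sizes recorded during tracing
    dynamic_axes = {'input_1': {0: 'batch', 2: 'width', 3: 'height'}}

    name = 'input_1'
    # A string in the shape becomes dim_param; an int stays a fixed dim_value.
    dims = [dynamic_axes.get(name, {}).get(i, s) for i, s in enumerate(sizes)]
    value_info = helper.make_tensor_value_info(name, TensorProto.FLOAT, dims)
    print(value_info)  # dims 0, 2, 3 carry dim_param; dim 1 carries dim_value 3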

torch/csrc/jit/export.h

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
+    const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
     bool defer_weight_export = false,
     ::torch::onnx::OperatorExportTypes operator_export_type =
         ::torch::onnx::OperatorExportTypes::ONNX,

torch/csrc/jit/python_ir.cpp

Lines changed: 3 additions & 0 deletions
@@ -231,6 +231,7 @@ void initPythonIRBindings(PyObject* module_) {
       [](const std::shared_ptr<Graph> g,
          const std::map<std::string, at::Tensor>& initializers,
          int64_t onnx_opset_version,
+         const std::unordered_map<std::string, std::unordered_map<int64_t, std::string>>& dynamic_axes,
          bool defer_weight_export,
          ::torch::onnx::OperatorExportTypes operator_export_type,
          bool strip_doc_string) {
@@ -240,6 +241,7 @@ void initPythonIRBindings(PyObject* module_) {
             g,
             initializers,
             onnx_opset_version,
+            dynamic_axes,
             defer_weight_export,
             operator_export_type,
             strip_doc_string);
@@ -259,6 +261,7 @@ void initPythonIRBindings(PyObject* module_) {
       },
       py::arg("initializers"),
       py::arg("onnx_opset_version") = 0,
+      py::arg("dynamic_axes"),
       py::arg("defer_weight_export") = false,
       py::arg("operator_export_type") =
           ::torch::onnx::OperatorExportTypes::ONNX,

torch/onnx/utils.py

Lines changed: 86 additions & 6 deletions
@@ -56,7 +56,7 @@ def set_training(model, mode):
 def export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, aten=False, export_raw_ir=False,
            operator_export_type=None, opset_version=None, _retain_param_name=True,
-           do_constant_folding=False, example_outputs=None, strip_doc_string=True):
+           do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None):
     r"""
     Export a model into ONNX format. This exporter runs your model
     once in order to get a trace of its execution to be exported;
@@ -117,6 +117,41 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
         strip_doc_string (bool, default True): if True, strips the field
             "doc_string" from the exported model, which contains information about the
             stack trace.
+        example_outputs: example outputs of the model that is being exported.
+        dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):
+            a dictionary to specify dynamic axes of input/output, such that:
+            - KEY: input and/or output names
+            - VALUE: indices of dynamic axes for the given key, and potentially the names
+              to be used for the exported dynamic axes. In general the value is defined
+              according to one of the following ways, or a combination of both:
+
+            (1). A list of integers specifying the dynamic axes of the provided input. In
+              this scenario automated names will be generated and applied to the dynamic
+              axes of the provided input/output during export.
+
+            OR (2). An inner dictionary that specifies a mapping FROM the index of a dynamic
+              axis in the corresponding input/output TO the name that is desired to be
+              applied on such axis of such input/output during export.
+
+            Example: if we have the following shape for inputs and outputs:
+                shape(input_1) = ('b', 3, 'w', 'h')
+                and shape(input_2) = ('b', 4)
+                and shape(output) = ('b', 'd', 5)
+
+            Then dynamic axes can be defined either as:
+                (a). ONLY INDICES:
+                    dynamic_axes = {'input_1':[0, 2, 3], 'input_2':[0], 'output':[0, 1]}
+                    where automatic names will be generated for exported dynamic axes
+
+                OR (b). INDICES WITH CORRESPONDING NAMES:
+                    dynamic_axes = {'input_1':{0:'batch', 2:'width', 3:'height'},
+                                    'input_2':{0:'batch'},
+                                    'output':{0:'batch', 1:'detections'}}
+                    where provided names will be applied to exported dynamic axes
+
+                OR (c). MIXED MODE OF (a) and (b):
+                    dynamic_axes = {'input_1':[0, 2, 3], 'input_2':{0:'batch'}, 'output':[0,1]}
+
     """
     if aten or export_raw_ir:
         assert operator_export_type is None
@@ -130,7 +165,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
     _export(model, args, f, export_params, verbose, training, input_names, output_names,
             operator_export_type=operator_export_type, opset_version=opset_version,
             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
-            example_outputs=example_outputs, strip_doc_string=strip_doc_string)
+            example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
 
 
 # ONNX can't handle constants that are lists of tensors, which can
@@ -349,7 +384,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
             input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
             opset_version=None, _retain_param_name=False, do_constant_folding=False,
-            strip_doc_string=True):
+            strip_doc_string=True, dynamic_axes=None):
     global __IN_ONNX_EXPORT
     assert __IN_ONNX_EXPORT is False
     __IN_ONNX_EXPORT = True
@@ -366,11 +401,17 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
 
     # TODO: Don't allocate an in-memory string for the protobuf
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
+    if dynamic_axes is None:
+        dynamic_axes = {}
+
+    _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
+
     if export_params:
-        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type,
-                                               strip_doc_string)
+        proto, export_map = graph._export_onnx(
+            params_dict, opset_version, dynamic_axes, defer_weight_export, operator_export_type, strip_doc_string)
     else:
-        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string)
+        proto, export_map = graph._export_onnx(
+            {}, opset_version, dynamic_axes, False, operator_export_type, strip_doc_string)
 
     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)
@@ -721,6 +762,45 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
     import torch.onnx.symbolic_registry as sym_registry
     sym_registry.register_op(op_name, symbolic_fn, ns, opset_version)
 
+# This helper function ensures the dynamic_axes argument follows the expected format
+def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
+    if len(dynamic_axes) == 0:
+        return
+
+    if hasattr(model, 'graph'):
+        # Extract the set of valid input/output names that may be used for dynamic_axes
+        if (input_names is None) or len(input_names) == 0:
+            input_names = [x.uniqueName() for x in model.graph.inputs()]
+        if (output_names is None) or len(output_names) == 0:
+            output_names = [y.uniqueName() for y in model.graph.outputs()]
+
+    valid_names = set()
+    if input_names is not None:
+        valid_names.update(input_names)
+    if output_names is not None:
+        valid_names.update(output_names)
+
+    # If dynamic axes are provided as a list rather than a dictionary, they should
+    # first get converted to a dictionary in the expected format. If desired axis names
+    # are not provided for the dynamic axes, automatic names are generated for the
+    # provided dynamic axes of the specified input/output
+    for key, value in dynamic_axes.items():
+        if key not in valid_names:
+            warnings.warn("Provided key {} for dynamic axes is not a valid input/output name".format(key))
+        if isinstance(value, list):
+            warnings.warn('No names were found for specified dynamic axes of provided input. '
+                          'Automatically generated names will be applied to each dynamic axis of input {}'.format(key))
+
+            value_dict = {}
+            for i, x in enumerate(value):
+                if not isinstance(x, int):
+                    raise ValueError("The type of axis index is expected to be an integer")
+                if x in value_dict:
+                    warnings.warn('Duplicate dynamic axis index {} was provided for input {}.'
+                                  .format(x, key))
+                else:
+                    value_dict[x] = str(key) + '_dynamic_axes_' + str(i + 1)
+            dynamic_axes[key] = value_dict
 
 torch._C.Graph.op = _graph_op
 torch._C.Graph.at = _graph_at
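For reference, _validate_dynamic_axes rewrites list-valued entries in place into the dictionary form, generating names as '<key>_dynamic_axes_<n>'. A sketch of the resulting value, matching the asserts in the test above:

    dynamic_axes = {'input_1': [0, 2, 3]}
    # After _validate_dynamic_axes(dynamic_axes, model, ['input_1'], ['output_1']):
    # {'input_1': {0: 'input_1_dynamic_axes_1',
    #              2: 'input_1_dynamic_axes_2',
    #              3: 'input_1_dynamic_axes_3'}}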
