@@ -56,7 +56,7 @@ def set_training(model, mode):
5656def export (model , args , f , export_params = True , verbose = False , training = False ,
5757 input_names = None , output_names = None , aten = False , export_raw_ir = False ,
5858 operator_export_type = None , opset_version = None , _retain_param_name = True ,
59- do_constant_folding = False , example_outputs = None , strip_doc_string = True ):
59+ do_constant_folding = False , example_outputs = None , strip_doc_string = True , dynamic_axes = None ):
6060 r"""
6161 Export a model into ONNX format. This exporter runs your model
6262 once in order to get a trace of its execution to be exported;
@@ -117,6 +117,41 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
117117 strip_doc_string (bool, default True): if True, strips the field
118118 "doc_string" from the exported model, which information about the stack
119119 trace.
120+ example_outputs: example outputs of the model that is being exported.
121+ dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):
122+ a dictionary to specify dynamic axes of input/output, such that:
123+ - KEY: input and/or output names
124+ - VALUE: index of dynamic axes for given key and potentially the name to be used for
125+ exported dynamic axes. In general the value is defined according to one of the following
126+ ways or a combination of both:
127+
128+ (1). A list of integers specifiying the dynamic axes of provided input. In this scenario
129+ automated names will be generated and applied to dynamic axes of provided input/output
130+ during export.
131+
132+ OR (2). An inner dictionary that specifies a mapping FROM the index of dynamic axis in
133+ corresponding input/output TO the name that is desired to be applied on such axis of
134+ such input/output during export.
135+
136+ Example. if we have the following shape for inputs and outputs:
137+ shape(input_1) = ('b', 3, 'w', 'h')
138+ and shape(input_2) = ('b', 4)
139+ and shape(output) = ('b', 'd', 5)
140+
141+ Then dynamic axes can be defined either as:
142+ (a). ONLY INDICES:
143+ dynamic_axes = {'input_1':[0, 2, 3], 'input_2':[0], 'output':[0, 1]}
144+ where automatic names will be generated for exported dynamic axes
145+
146+ OR (b). INDICES WITH CORRESPONDING NAMES:
147+ dynamic_axes = {'input_1':{0:'batch', 1:'width', 2:'height'},
148+ 'input_2':{0:'batch'},
149+ 'output':{0:'batch', 1:'detections'}
150+ where provided names will be applied to exported dynamic axes
151+
152+ OR (c). MIXED MODE OF (a) and (b)
153+ dynamic_axes = {'input_1':[0, 2, 3], 'input_2':{0:'batch'}, 'output':[0,1]}
154+
120155 """
121156 if aten or export_raw_ir :
122157 assert operator_export_type is None
@@ -130,7 +165,7 @@ def export(model, args, f, export_params=True, verbose=False, training=False,
130165 _export (model , args , f , export_params , verbose , training , input_names , output_names ,
131166 operator_export_type = operator_export_type , opset_version = opset_version ,
132167 _retain_param_name = _retain_param_name , do_constant_folding = do_constant_folding ,
133- example_outputs = example_outputs , strip_doc_string = strip_doc_string )
168+ example_outputs = example_outputs , strip_doc_string = strip_doc_string , dynamic_axes = dynamic_axes )
134169
135170
136171# ONNX can't handle constants that are lists of tensors, which can
@@ -349,7 +384,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
349384 input_names = None , output_names = None , operator_export_type = OperatorExportTypes .ONNX ,
350385 export_type = ExportTypes .PROTOBUF_FILE , example_outputs = None , propagate = False ,
351386 opset_version = None , _retain_param_name = False , do_constant_folding = False ,
352- strip_doc_string = True ):
387+ strip_doc_string = True , dynamic_axes = None ):
353388 global __IN_ONNX_EXPORT
354389 assert __IN_ONNX_EXPORT is False
355390 __IN_ONNX_EXPORT = True
@@ -366,11 +401,17 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
366401
367402 # TODO: Don't allocate a in-memory string for the protobuf
368403 defer_weight_export = export_type is not ExportTypes .PROTOBUF_FILE
404+ if dynamic_axes is None :
405+ dynamic_axes = {}
406+
407+ _validate_dynamic_axes (dynamic_axes , model , input_names , output_names )
408+
369409 if export_params :
370- proto , export_map = graph ._export_onnx (params_dict , opset_version , defer_weight_export , operator_export_type ,
371- strip_doc_string )
410+ proto , export_map = graph ._export_onnx (
411+ params_dict , opset_version , dynamic_axes , defer_weight_export , operator_export_type , strip_doc_string )
372412 else :
373- proto , export_map = graph ._export_onnx ({}, opset_version , False , operator_export_type , strip_doc_string )
413+ proto , export_map = graph ._export_onnx (
414+ {}, opset_version , dynamic_axes , False , operator_export_type , strip_doc_string )
374415
375416 if export_type == ExportTypes .PROTOBUF_FILE :
376417 assert (len (export_map ) == 0 )
@@ -721,6 +762,45 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
721762 import torch .onnx .symbolic_registry as sym_registry
722763 sym_registry .register_op (op_name , symbolic_fn , ns , opset_version )
723764
765+ # This helper function ensures dynamic axes argument is following the expected format
766+ def _validate_dynamic_axes (dynamic_axes , model , input_names , output_names ):
767+ if len (dynamic_axes ) == 0 :
768+ return
769+
770+ if (hasattr (model , 'graph' )):
771+ # Extracting set of valid input/output names that shall be used for dynamic_axes
772+ if (input_names is None ) or len (input_names ) == 0 :
773+ input_names = [x .uniqueName () for x in model .graph .inputs ()]
774+ if (output_names is None ) or len (output_names ) == 0 :
775+ output_names = [y .uniqueName () for y in model .graph .outputs ()]
776+
777+ valid_names = set ()
778+ if input_names is not None :
779+ valid_names .add (x for x in input_names )
780+ if output_names is not None :
781+ valid_names .add (x for x in output_names )
782+
783+ # If dynamic axes are provided as a list rather than dictionary, they should
784+ # first get converted to a dictionary in expected format. If desired axes names
785+ # are not provided for dynamic axes, automatic names shall be generated for
786+ # provided dynamic axes of specified input/output
787+ for key , value in dynamic_axes .items ():
788+ if key not in valid_names :
789+ warnings .warn ("Provided key {} for dynamic axes is not a valid input/output name" .format (key ))
790+ if isinstance (value , list ):
791+ warnings .warn ('No names were found for specified dynamic axes of provided input.'
792+ 'Automatically generated names will be applied to each dynamic axes of input {}' .format (key ))
793+
794+ value_dict = {}
795+ for i , x in enumerate (value ):
796+ if not isinstance (x , int ):
797+ raise ValueError ("The type of axis index is expected to be an integer" )
798+ if x in value_dict :
799+ warnings .warn ('Duplicate dynamic axis index {} was provided for input {}.'
800+ .format (x , key ))
801+ else :
802+ value_dict [x ] = str (key ) + '_dynamic_axes_' + str (i + 1 )
803+ dynamic_axes [key ] = value_dict
724804
725805torch ._C .Graph .op = _graph_op
726806torch ._C .Graph .at = _graph_at
0 commit comments