Commit f108953

Author: Thiago Crepaldi

Remove unnecessary constant outputs from ONNX exported graph
TracingAdapter creates extra outputs (through flatten_to_tuple) to store the metadata needed to rebuild the original data format. This is unnecessary during ONNX export, because the original data will never be reconstructed into its original format through the Schema.__call__ API. This PR suppresses such extra constant outputs during torch.onnx.export() execution. Outside this API the behavior is unchanged, ensuring backward compatibility. Although not strictly necessary to achieve the same numerical results as PyTorch, when an ONNX model's schema is compared to PyTorch's, the different number of outputs (the ONNX model will have more outputs than the PyTorch model) may not only confuse users, but also produce false negatives in model-comparison helpers.
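The flatten/rebuild mechanism that the commit message refers to can be illustrated with a simplified, self-contained sketch. The real `flatten_to_tuple` and `Schema` live in detectron2/export/flatten.py; the names `toy_flatten` and `toy_rebuild` below are hypothetical stand-ins, not detectron2 APIs. The point is that flattening a rich output yields both tensor-like values and extra metadata, and it is that metadata which would surface as unnecessary constant outputs in an exported ONNX graph.

```python
# Hypothetical, simplified model of detectron2's flatten_to_tuple / Schema idea:
# a rich output (here a dict) is flattened into a flat tuple, plus a "schema"
# recording the metadata needed to rebuild the original structure.

def toy_flatten(obj):
    """Flatten a dict of tensor-like values into a tuple plus a schema."""
    keys = tuple(sorted(obj))
    flat = tuple(obj[k] for k in keys)
    # `keys` is the metadata; if returned alongside the tensors during export,
    # it would show up as extra constant outputs in the ONNX graph.
    return flat, keys

def toy_rebuild(flat, schema):
    """Analogue of Schema.__call__: rebuild the original dict from the tuple."""
    return dict(zip(schema, flat))

rich = {"boxes": [1, 2, 3], "scores": [0.9, 0.8]}
flat, schema = toy_flatten(rich)
assert toy_rebuild(flat, schema) == rich
# During ONNX export only `flat` is needed; the schema metadata is never used
# to rebuild the output, which is why the commit drops it from the graph.
```

This is only a sketch of the idea under those naming assumptions; the real `Schema` handles nested structures such as Instances and Boxes.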
1 parent b352172 · commit f108953

File tree

1 file changed: +26 −0 lines changed

detectron2/export/flatten.py

Lines changed: 26 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import collections
+import warnings
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple
 import torch
@@ -17,6 +18,7 @@ class Schema:
     A Schema defines how to flatten a possibly hierarchical object into tuple of
     primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.
 
+    The flatten representation can be used as inputs/outputs of PyTorch's tracing.
     PyTorch does not support tracing a function that produces rich output
     structures (e.g. dict, Instances, Boxes). To trace such a function, we
     flatten the rich object into tuple of tensors, and return this tuple of tensors
@@ -254,6 +256,7 @@ def __init__(
             inputs = (inputs,)
         self.inputs = inputs
         self.allow_non_tensor = allow_non_tensor
+        self._is_in_torch_onnx_export = torch.onnx.is_in_onnx_export()
 
         if inference_func is None:
             inference_func = lambda model, *inputs: model(*inputs)  # noqa
@@ -307,6 +310,21 @@ def forward(self, *args: torch.Tensor):
                         "cannot flatten to tensors."
                     )
             else:  # schema is valid
+                # During torch.onnx.export(), extra outputs that `Schema` implementations
+                # can generate (e.g. `InstancesSchema`) are dropped from the final ONNX graph.
+                # This is OK because ONNX graphs do not need to rebuild the original data.
+                if (
+                    len(flattened_output_tensors) > len(flattened_outputs)
+                    and self._is_in_torch_onnx_export
+                ):
+                    warnings.warn(
+                        "PyTorch ONNX export (`torch.onnx.export`) detected!"
+                        " To prevent extra outputs in the ONNX graph, the original"
+                        " model output cannot be reconstructed through"
+                        " `adapter.outputs_schema(flattened_outputs)`."
+                        " For results evaluation, use `torch.jit.trace` instead."
+                    )
+                    flattened_outputs = flattened_outputs[: len(outputs)]
                 if self.outputs_schema is None:
                     self.outputs_schema = schema
                 else:
@@ -323,6 +341,14 @@ def _create_wrapper(self, traced_model):
         """
 
         def forward(*args):
+            if traced_model._is_in_torch_onnx_export:
+                warnings.warn(
+                    "PyTorch ONNX export (`torch.onnx.export`) detected!"
+                    " To prevent extra outputs in the ONNX graph, the original"
+                    " model output cannot be reconstructed through"
+                    " `adapter.outputs_schema(flattened_outputs)`."
+                    " For results evaluation, use `torch.jit.trace` instead."
+                )
             flattened_inputs, _ = flatten_to_tuple(args)
             flattened_outputs = traced_model(*flattened_inputs)
             return self.outputs_schema(flattened_outputs)
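The control flow the diff adds can be sketched in isolation. This is pure Python with no torch dependency; `is_in_onnx_export` and `num_real_outputs` are hypothetical stand-ins for `torch.onnx.is_in_onnx_export()` and the adapter's output bookkeeping. The key property is that outside ONNX export the outputs pass through untouched, which is what keeps the change backward compatible.

```python
import warnings

def prune_extra_outputs(flattened_outputs, num_real_outputs, is_in_onnx_export):
    """Sketch of the commit's gating: drop schema-metadata outputs only while
    torch.onnx.export() is running; tracing behavior is otherwise unchanged."""
    if is_in_onnx_export and len(flattened_outputs) > num_real_outputs:
        # Mirrors the diff's warnings.warn(): the caller is told the original
        # output structure can no longer be rebuilt from the exported graph.
        warnings.warn(
            "ONNX export detected: extra schema outputs are dropped, so the"
            " original model output cannot be reconstructed."
        )
        return flattened_outputs[:num_real_outputs]
    # Outside ONNX export: no truncation, preserving backward compatibility.
    return flattened_outputs

# Outside export, nothing changes:
assert prune_extra_outputs((1, 2, 3), 2, False) == (1, 2, 3)
```

This only models the truncation step under those assumptions; in the real patch the flag is captured once in `TracingAdapter.__init__` and consulted again in the traced-model wrapper.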
