Closed
Labels
module: onnx (Related to torch.onnx)
onnx-triaged (triaged by ONNX team)
release notes: onnx (torch.onnx related changes that should show up in the release notes)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
functorch.experimental.control_flow.cond is not yet supported by torch.onnx.dynamo_export
import torch
from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - the returned tensors must have the same tensor metadata, e.g. shape and dtype
    - a branch function can be a free function, nested function, lambda, or class method
    - a branch function cannot have closure variables
    - no in-place mutations on inputs or global variables

    This example demonstrates using a class method in cond().

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])


model = CondBranchClassMethod()
input = torch.randn(5)
# exported_program = torch.export.export(model, args=(input,))(input)  # works
onnx_program = torch.onnx.dynamo_export(model, input)  # Unknown call_function target: cond

Error:
Traceback (most recent call last):
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1439, in dynamo_export
    ).export()
      ^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1182, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 232, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<@beartype(torch.onnx._internal.fx.dynamo_graph_extractor.DynamoExport.pre_export_passes) at 0x7fb951ae0fe0>", line 93, in pre_export_passes
  File "/opt/pytorch/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 242, in pre_export_passes
    return exporter.common_pre_export_passes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1495, in common_pre_export_passes
    ).analyze(infra.levels.ERROR)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 60, in analyze
    self.onnxfunction_dispatcher._get_aten_name(
  File "<@beartype(torch.onnx._internal.fx.onnxfunction_dispatcher.OnnxFunctionDispatcher._get_aten_name) at 0x7fb951cb0720>", line 54, in _get_aten_name
  File "/opt/pytorch/torch/onnx/_internal/fx/onnxfunction_dispatcher.py", line 352, in _get_aten_name
    raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unknown call_function target: cond

Versions
pytorch main branch
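The traceback shows the export failing in a pre-export analysis pass (`unsupported_nodes.py`, `analyze`), which asks the function dispatcher (`_get_aten_name`) to resolve each `call_function` target; the `cond` node captured from the model has no known translation, so the whole export stops. A simplified, hypothetical model of that kind of check is sketched below in plain Python. `Node`, `SUPPORTED_TARGETS`, `find_unsupported` and `check_graph` are invented for illustration and are not torch.onnx APIs; the real exporter works on FX graph nodes and raises `RuntimeErrorWithDiagnostic`, not a plain `RuntimeError`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an FX graph node (illustration only)."""
    op: str      # e.g. "placeholder", "call_function", "output"
    target: str  # callee name for call_function nodes


# Hypothetical registry of call_function targets that have an ONNX translation.
SUPPORTED_TARGETS = {"aten.sin.default", "aten.cos.default"}


def find_unsupported(nodes):
    """Collect call_function targets with no registered translation, in graph order."""
    missing = []
    for node in nodes:
        if node.op == "call_function" and node.target not in SUPPORTED_TARGETS:
            if node.target not in missing:
                missing.append(node.target)
    return missing


def check_graph(nodes):
    """Reject the graph before translation if any target is unknown."""
    missing = find_unsupported(nodes)
    if missing:
        raise RuntimeError(f"Unknown call_function target: {missing[0]}")


# A toy graph shaped like the repro: the cond higher-order op is a single node.
graph = [
    Node("placeholder", "x"),
    Node("call_function", "cond"),
    Node("output", "output"),
]
```

Under these assumptions, `find_unsupported(graph)` reports `["cond"]` and `check_graph(graph)` fails with the same message as the issue; supporting `cond` would amount to registering a translation for that target so the analysis pass stops flagging it.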
Status: Done