[ONNX] dynamo: support conditional op cond  #117655

@thiagocrepaldi

Description

🐛 Describe the bug

`functorch.experimental.control_flow.cond` is not yet supported by `torch.onnx.dynamo_export`, even though `torch.export.export` handles the same model.

```python
import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond
      - both branches must return a single tensor
      - the returned tensors must have the same tensor metadata, e.g. shape and dtype
      - a branch function can be a free function, nested function, lambda, or class method
      - a branch function cannot have closure variables
      - no in-place mutations of inputs or global variables

    This example demonstrates using a class method in cond().

    NOTE: If `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])


model = CondBranchClassMethod()
input = torch.randn(5)

# exported_program = torch.export.export(model, args=(input,)).module()(input)  # works
onnx_program = torch.onnx.dynamo_export(model, input)  # raises: Unknown call_function target: cond
```
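To make the branch rules in the docstring concrete, here is a minimal plain-Python sketch, not the functorch implementation. `reference_cond` shows the eager semantics (run exactly one branch on the shared operands), and `check_branch_metadata` is a hypothetical trace-time check: when `pred` depends on data or symbolic shapes a tracer cannot pick a branch, so it evaluates both on metadata and requires them to agree. `Meta`, both functions, and the list-based "tensors" are illustrative assumptions.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for tensor metadata (shape and dtype only)."""
    shape: tuple
    dtype: str


def reference_cond(pred: bool, true_fn: Callable, false_fn: Callable, operands: Sequence[Any]) -> Any:
    """Eager-mode reference: call exactly one branch with the same operands."""
    branch = true_fn if pred else false_fn
    return branch(*operands)


def check_branch_metadata(true_fn: Callable, false_fn: Callable, operand_meta: Sequence[Meta]) -> Meta:
    """Hypothetical trace-time check: evaluate both branches on metadata
    and require a single, identical output description."""
    out_true = true_fn(*operand_meta)
    out_false = false_fn(*operand_meta)
    if not isinstance(out_true, Meta) or not isinstance(out_false, Meta):
        raise TypeError("each branch must return a single tensor")
    if out_true != out_false:
        raise ValueError(f"branch outputs differ: {out_true} vs {out_false}")
    return out_true


# Elementwise branches, loosely mirroring MySubModule.forward (cos) and bar (sin).
def cos_branch(values):
    return [math.cos(v) for v in values]


def sin_branch(values):
    return [math.sin(v) for v in values]


def identity_meta(meta: Meta) -> Meta:
    """Elementwise ops such as cos/sin keep the input's shape and dtype."""
    return meta


if __name__ == "__main__":
    # As in the repro, pred is `x.shape[0] <= 2` with x.shape[0] == 5, so false_fn runs.
    print(reference_cond(5 <= 2, cos_branch, sin_branch, [[0.0, 1.0]]))
    print(check_branch_metadata(identity_meta, identity_meta, [Meta((5,), "float32")]))
```

A branch that changed the shape (for example, a reduction in only one branch) would fail the metadata check, which is the rule "returned tensor must have the same tensor metadata" expressed as code.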

Error:

```
Traceback (most recent call last):
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1439, in dynamo_export
    ).export()
      ^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1182, in export
    graph_module = self.options.fx_tracer.generate_fx(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 232, in generate_fx
    return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<@beartype(torch.onnx._internal.fx.dynamo_graph_extractor.DynamoExport.pre_export_passes) at 0x7fb951ae0fe0>", line 93, in pre_export_passes
  File "/opt/pytorch/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 242, in pre_export_passes
    return exporter.common_pre_export_passes(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/exporter.py", line 1495, in common_pre_export_passes
    ).analyze(infra.levels.ERROR)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 60, in analyze
    self.onnxfunction_dispatcher._get_aten_name(
  File "<@beartype(torch.onnx._internal.fx.onnxfunction_dispatcher.OnnxFunctionDispatcher._get_aten_name) at 0x7fb951cb0720>", line 54, in _get_aten_name
  File "/opt/pytorch/torch/onnx/_internal/fx/onnxfunction_dispatcher.py", line 352, in _get_aten_name
    raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unknown call_function target: cond
```
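The traceback shows the export failing in a pre-export analysis pass (`unsupported_nodes.py`, `analyze`), which asks the dispatcher for each node's ATen name via `_get_aten_name` and raises when a `call_function` target cannot be resolved. The sketch below mimics that flow in plain Python so the failure mode is easy to see; `Node`, `find_unsupported`, `analyze`, `UnsupportedTargetError`, and the registry dict are hypothetical stand-ins, not the `torch.onnx` internals.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an FX node: its op kind and a printable target name."""
    op: str      # e.g. "placeholder", "call_function", "output"
    target: str  # name of the callee for call_function nodes


class UnsupportedTargetError(RuntimeError):
    """Plays the role of RuntimeErrorWithDiagnostic in the traceback."""


def find_unsupported(nodes: List[Node], registry: Dict[str, Callable]) -> List[str]:
    """Return call_function targets that have no registered ONNX translation."""
    return [n.target for n in nodes if n.op == "call_function" and n.target not in registry]


def analyze(nodes: List[Node], registry: Dict[str, Callable]) -> None:
    """Fail before any translation starts if some node cannot be dispatched."""
    missing = find_unsupported(nodes, registry)
    if missing:
        raise UnsupportedTargetError("Unknown call_function target: " + ", ".join(missing))


if __name__ == "__main__":
    # A simplified view of the repro's graph: cond shows up as one call_function node.
    graph = [Node("placeholder", "x"), Node("call_function", "cond"), Node("output", "output")]
    registry = {"aten::cos": lambda: None, "aten::sin": lambda: None}  # hypothetical entries
    try:
        analyze(graph, registry)
    except UnsupportedTargetError as err:
        print(err)  # Unknown call_function target: cond
```

Under this model, supporting `cond` would plausibly mean registering a translation for it; ONNX's `If` operator, whose `then_branch` and `else_branch` graph attributes correspond to `true_fn` and `false_fn`, is the natural target.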

Versions

PyTorch main branch

Metadata

Labels: module: onnx, onnx-triaged, release notes: onnx, triaged

Status: Done
