Skip to content

Commit 9f7c26b

Browse files
tugsbayasgalan and pytorchmergebot
authored and committed
Fix training IR bug by changing passes order (#138292)
Inserting runtime assertions causes the gm to have different names, but the graph signature was populated before that pass ran, so the signature could disagree with the final graph. To avoid this kind of error in the future, I refactored these steps into a helper function. Differential Revision: [D64576251](https://our.internmc.facebook.com/intern/diff/D64576251) Pull Request resolved: #138292 Approved by: https://github.com/avikchaudhuri ghstack dependencies: #138266
1 parent 012ff2a commit 9f7c26b

File tree

3 files changed

+51
-49
lines changed

3 files changed

+51
-49
lines changed

test/export/test_export.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,17 @@ def forward(self, x):
399399

400400
self.assertEqual(counter, 1)
401401

402+
def test_symint_output(self):
403+
class Foo(torch.nn.Module):
404+
def forward(self, x):
405+
z, y = x.size()
406+
return z + y + x[0], z
407+
408+
inputs = (torch.ones(2, 3),)
409+
dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x")
410+
dynamic_shapes = {"x": (dim0_x, dim1_x)}
411+
export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
412+
402413
def test_no_tensor_computation(self):
403414
class Module(torch.nn.Module):
404415
def forward(self, x, y):
@@ -978,7 +989,6 @@ def forward(self, x):
978989
ep_model = export(model, (x,), strict=False).module()
979990
self.assertTrue(torch.allclose(model(x), ep_model(x)))
980991

981-
@testing.expectedFailureTrainingIRToRunDecompNonStrict # TODO(pianpwk): user_output signature
982992
def test_real_tensor_for_max_op(self):
983993
class Foo(torch.nn.Module):
984994
def forward(self, x, y):

torch/_export/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from torch._guards import detect_fake_mode
2626
from torch._subclasses.fake_tensor import FakeTensor
2727
from torch._subclasses.functional_tensor import FunctionalTensor
28+
from torch.fx._utils import first_call_function_nn_module_stack
29+
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
2830

2931

3032
if TYPE_CHECKING:
@@ -533,6 +535,35 @@ def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.No
533535
return [node for node in nodes if node_call_back(node)]
534536

535537

538+
def apply_runtime_assertion_pass(gm, graph_signature):
539+
from torch._export.passes._node_metadata_hook import (
540+
_node_metadata_hook,
541+
_set_node_metadata_hook,
542+
)
543+
from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names
544+
545+
if not torch._dynamo.config.do_not_emit_runtime_asserts:
546+
stack_trace = (
547+
'File "torch/fx/passes/runtime_assert.py", line 24, '
548+
"in insert_deferred_runtime_asserts"
549+
)
550+
with _set_node_metadata_hook(
551+
gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
552+
):
553+
shape_env = _get_shape_env_from_gm(gm)
554+
if shape_env:
555+
insert_deferred_runtime_asserts(
556+
gm,
557+
shape_env,
558+
f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
559+
export=True,
560+
)
561+
# update output specs
562+
gm.recompile()
563+
graph_signature.user_outputs = _graph_output_names(gm)
564+
return gm, graph_signature
565+
566+
536567
def nodes_first(
537568
nodes: List[torch.fx.Node], node_call_back=None
538569
) -> Optional[torch.fx.Node]:

torch/export/_trace.py

Lines changed: 9 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@
2828
make_fake_inputs,
2929
produce_guards_and_solve_constraints,
3030
)
31-
from torch._export.passes._node_metadata_hook import (
32-
_node_metadata_hook,
33-
_set_node_metadata_hook,
34-
)
3531
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
3632
from torch._export.passes.lift_constants_pass import (
3733
ConstantAttrMap,
@@ -40,9 +36,9 @@
4036
)
4137
from torch._export.utils import (
4238
_collect_param_buffer_metadata,
43-
_get_shape_env_from_gm,
4439
_populate_param_buffer_metadata_to_new_gm,
4540
_update_gm_meta_if_possible,
41+
apply_runtime_assertion_pass,
4642
placeholder_naming_pass,
4743
placeholder_prefixes,
4844
)
@@ -70,7 +66,6 @@
7066
from torch.export._unlift import _check_input_constraints_pre_hook
7167
from torch.export.dynamic_shapes import _check_dynamic_shapes, _combine_args
7268
from torch.export.exported_program import OutputKind
73-
from torch.fx._utils import first_call_function_nn_module_stack
7469
from torch.fx.experimental.proxy_tensor import make_fx
7570
from torch.fx.experimental.symbolic_shapes import (
7671
ConstraintViolationError,
@@ -80,7 +75,6 @@
8075
)
8176
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
8277
from torch.fx.graph_module import _get_attr
83-
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
8478
from torch.utils._pytree import TreeSpec
8579
from torch.utils._sympy.value_ranges import ValueRangeError
8680

@@ -685,26 +679,7 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
685679
# Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
686680
# Overwrite output specs afterwards.
687681
flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
688-
if not torch._dynamo.config.do_not_emit_runtime_asserts:
689-
stack_trace = (
690-
'File "torch/fx/passes/runtime_assert.py", line 24, '
691-
"in insert_deferred_runtime_asserts"
692-
)
693-
with _set_node_metadata_hook(
694-
gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
695-
):
696-
shape_env = _get_shape_env_from_gm(gm)
697-
if shape_env:
698-
insert_deferred_runtime_asserts(
699-
gm,
700-
shape_env,
701-
f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
702-
export=True,
703-
)
704-
705-
# update output specs
706-
gm.recompile()
707-
graph_signature.user_outputs = _graph_output_names(gm)
682+
gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)
708683

709684
total_non_user_inputs = (
710685
len(graph_signature.parameters)
@@ -1535,11 +1510,6 @@ def wrapped_fn(*args):
15351510
gm.meta.update(mod.meta)
15361511

15371512
flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
1538-
set_missing_meta_vals(gm, flat_args, params_len)
1539-
1540-
export_graph_signature = _convert_to_export_graph_signature(
1541-
graph_signature, gm, _get_non_persistent_buffers(mod)
1542-
)
15431513

15441514
# See comment in _export_to_aten_ir()
15451515
if produce_guards_callback:
@@ -1548,22 +1518,7 @@ def wrapped_fn(*args):
15481518
except (ConstraintViolationError, ValueRangeError) as e:
15491519
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
15501520

1551-
fake_mode = detect_fake_mode(flat_args)
1552-
1553-
if not torch._dynamo.config.do_not_emit_runtime_asserts:
1554-
stack_trace = (
1555-
'File "torch/fx/passes/runtime_assert.py", line 24, '
1556-
"in insert_deferred_runtime_asserts"
1557-
)
1558-
with _set_node_metadata_hook(
1559-
gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace)
1560-
):
1561-
insert_deferred_runtime_asserts(
1562-
gm,
1563-
fake_mode.shape_env,
1564-
f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
1565-
export=True,
1566-
)
1521+
gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)
15671522

15681523
# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
15691524
for _mod in gm.modules():
@@ -1574,6 +1529,12 @@ def wrapped_fn(*args):
15741529
node.meta.pop("nn_module_stack", None)
15751530
node.meta.pop("stack_trace", None)
15761531

1532+
set_missing_meta_vals(gm, flat_args, params_len)
1533+
1534+
export_graph_signature = _convert_to_export_graph_signature(
1535+
graph_signature, gm, _get_non_persistent_buffers(mod)
1536+
)
1537+
15771538
constants = rewrite_script_object_meta(gm)
15781539
constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))
15791540

0 commit comments

Comments
 (0)