2828 make_fake_inputs ,
2929 produce_guards_and_solve_constraints ,
3030)
31- from torch ._export .passes ._node_metadata_hook import (
32- _node_metadata_hook ,
33- _set_node_metadata_hook ,
34- )
3531from torch ._export .passes .collect_tracepoints_pass import CollectTracepointsPass
3632from torch ._export .passes .lift_constants_pass import (
3733 ConstantAttrMap ,
4036)
4137from torch ._export .utils import (
4238 _collect_param_buffer_metadata ,
43- _get_shape_env_from_gm ,
4439 _populate_param_buffer_metadata_to_new_gm ,
4540 _update_gm_meta_if_possible ,
41+ apply_runtime_assertion_pass ,
4642 placeholder_naming_pass ,
4743 placeholder_prefixes ,
4844)
7066from torch .export ._unlift import _check_input_constraints_pre_hook
7167from torch .export .dynamic_shapes import _check_dynamic_shapes , _combine_args
7268from torch .export .exported_program import OutputKind
73- from torch .fx ._utils import first_call_function_nn_module_stack
7469from torch .fx .experimental .proxy_tensor import make_fx
7570from torch .fx .experimental .symbolic_shapes import (
7671 ConstraintViolationError ,
8075)
8176from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
8277from torch .fx .graph_module import _get_attr
83- from torch .fx .passes .runtime_assert import insert_deferred_runtime_asserts
8478from torch .utils ._pytree import TreeSpec
8579from torch .utils ._sympy .value_ranges import ValueRangeError
8680
@@ -685,26 +679,7 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
685679 # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
686680 # Overwrite output specs afterwards.
687681 flat_fake_args = pytree .tree_leaves ((fake_args , fake_kwargs ))
688- if not torch ._dynamo .config .do_not_emit_runtime_asserts :
689- stack_trace = (
690- 'File "torch/fx/passes/runtime_assert.py", line 24, '
691- "in insert_deferred_runtime_asserts"
692- )
693- with _set_node_metadata_hook (
694- gm , functools .partial (_node_metadata_hook , stack_trace = stack_trace )
695- ):
696- shape_env = _get_shape_env_from_gm (gm )
697- if shape_env :
698- insert_deferred_runtime_asserts (
699- gm ,
700- shape_env ,
701- f"exported program: { first_call_function_nn_module_stack (gm .graph )} " ,
702- export = True ,
703- )
704-
705- # update output specs
706- gm .recompile ()
707- graph_signature .user_outputs = _graph_output_names (gm )
682+ gm , graph_signature = apply_runtime_assertion_pass (gm , graph_signature )
708683
709684 total_non_user_inputs = (
710685 len (graph_signature .parameters )
@@ -1535,11 +1510,6 @@ def wrapped_fn(*args):
15351510 gm .meta .update (mod .meta )
15361511
15371512 flat_args = pytree .tree_leaves ((fake_args , fake_kwargs ))
1538- set_missing_meta_vals (gm , flat_args , params_len )
1539-
1540- export_graph_signature = _convert_to_export_graph_signature (
1541- graph_signature , gm , _get_non_persistent_buffers (mod )
1542- )
15431513
15441514 # See comment in _export_to_aten_ir()
15451515 if produce_guards_callback :
@@ -1548,22 +1518,7 @@ def wrapped_fn(*args):
15481518 except (ConstraintViolationError , ValueRangeError ) as e :
15491519 raise UserError (UserErrorType .CONSTRAINT_VIOLATION , str (e )) # noqa: B904
15501520
1551- fake_mode = detect_fake_mode (flat_args )
1552-
1553- if not torch ._dynamo .config .do_not_emit_runtime_asserts :
1554- stack_trace = (
1555- 'File "torch/fx/passes/runtime_assert.py", line 24, '
1556- "in insert_deferred_runtime_asserts"
1557- )
1558- with _set_node_metadata_hook (
1559- gm , functools .partial (_node_metadata_hook , stack_trace = stack_trace )
1560- ):
1561- insert_deferred_runtime_asserts (
1562- gm ,
1563- fake_mode .shape_env ,
1564- f"exported program: { first_call_function_nn_module_stack (gm .graph )} " ,
1565- export = True ,
1566- )
1521+ gm , graph_signature = apply_runtime_assertion_pass (gm , graph_signature )
15671522
15681523 # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
15691524 for _mod in gm .modules ():
@@ -1574,6 +1529,12 @@ def wrapped_fn(*args):
15741529 node .meta .pop ("nn_module_stack" , None )
15751530 node .meta .pop ("stack_trace" , None )
15761531
1532+ set_missing_meta_vals (gm , flat_args , params_len )
1533+
1534+ export_graph_signature = _convert_to_export_graph_signature (
1535+ graph_signature , gm , _get_non_persistent_buffers (mod )
1536+ )
1537+
15771538 constants = rewrite_script_object_meta (gm )
15781539 constants .update (lift_constants_pass (gm , export_graph_signature , constant_attrs ))
15791540
0 commit comments