Skip to content

Commit fe87858

Browse files
avikchaudhuri authored and facebook-github-bot committed
preserve signatures with multiple calls + buffer mutations (#138669)
Summary: Pull Request resolved: #138669 Test Plan: modified test Differential Revision: D64806175
1 parent fe458ee commit fe87858

File tree

3 files changed

+178
-31
lines changed

3 files changed

+178
-31
lines changed

test/export/test_export.py

Lines changed: 25 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -6575,25 +6575,34 @@ def forward(self, x):
65756575
m = M()
65766576
eager_result = m(*inp)
65776577

6578-
if not is_retracebility_test(self._testMethodName):
6579-
with self.assertRaisesRegex(
6580-
ValueError,
6581-
r"Found multiple calls of module n that mutate buffer n.buf",
6582-
):
6583-
# Unflattening while preserving signatures is NYI for this case.
6584-
torch.export.unflatten(
6585-
export(M(), inp, preserve_module_call_signature=("n",))
6586-
)
6578+
def test(ep):
6579+
epm = ep.module()
6580+
ufm = torch.export.unflatten(ep)
65876581

6588-
ep = export(M(), inp)
6589-
epm = ep.module()
6590-
ufm = torch.export.unflatten(ep)
6582+
exported_result = epm(*inp)
6583+
self.assertTrue(torch.allclose(exported_result, eager_result))
65916584

6592-
exported_result = epm(*inp)
6593-
self.assertTrue(torch.allclose(exported_result, eager_result))
6585+
unflattened_result = ufm(*inp)
6586+
self.assertTrue(torch.allclose(unflattened_result, eager_result))
65946587

6595-
unflattened_result = ufm(*inp)
6596-
self.assertTrue(torch.allclose(unflattened_result, eager_result))
6588+
if not is_retracebility_test(self._testMethodName):
6589+
test(export(M(), inp, preserve_module_call_signature=("n",)))
6590+
# running decompositions again should work for all IRs
6591+
ep = export(M(), inp, preserve_module_call_signature=("n",))
6592+
test(ep.run_decompositions({}))
6593+
if is_training_ir_test(self._testMethodName):
6594+
# since we run decompositions by default when testing training IR,
6595+
# also test training IR without running decompositions
6596+
strict = not is_non_strict_test(self._testMethodName)
6597+
ept = torch.export.export_for_training(
6598+
M(),
6599+
inp,
6600+
strict=strict,
6601+
preserve_module_call_signature=("n",),
6602+
)
6603+
test(ept)
6604+
6605+
test(export(M(), inp))
65976606

65986607
def test_unflatten_multiple_graphs_shared_submodule(self):
65996608
class N(torch.nn.Module):

torch/export/exported_program.py

Lines changed: 34 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -643,6 +643,30 @@ def _common_getitem_elimination_pass(
643643
node_id[node] = node.name
644644

645645

646+
def _get_updated_module_call_graph(
647+
gm: torch.fx.GraphModule,
648+
old_module_call_graph: List[ModuleCallEntry],
649+
):
650+
new_module_call_graph = copy.deepcopy(old_module_call_graph)
651+
652+
# use node-level provenance metadata to create a map
653+
# from old node names to new node names
654+
provenance: Dict[str, str] = {}
655+
for node in gm.graph.nodes:
656+
if history := node.meta.get("from_node", []):
657+
provenance[history[-1][0]] = node.name
658+
659+
# map old names to new names in module call signatures
660+
for entry in new_module_call_graph:
661+
signature = entry.signature
662+
if signature is None:
663+
continue
664+
for x in [*signature.inputs, *signature.outputs]:
665+
x.name = provenance.get(x.name, x.name)
666+
667+
return new_module_call_graph
668+
669+
646670
def _decompose_exported_program(
647671
ep,
648672
*,
@@ -657,6 +681,15 @@ def _decompose_exported_program(
657681
joint_loss_index=joint_loss_index,
658682
)
659683

684+
# The signatures of ep.module_call_graph refer to input / output nodes of
685+
# the original graph module. However, the new graph module may have
686+
# new nodes due to decompositions. So we need to update these signatures
687+
# in the decomposed exported program's module_call_graph.
688+
new_module_call_graph = _get_updated_module_call_graph(
689+
gm,
690+
ep.module_call_graph,
691+
)
692+
660693
# TODO unfortunately preserving graph-level metadata is not
661694
# working well with aot_export. So we manually copy it.
662695
# (The node-level meta is addressed above.)
@@ -673,7 +706,7 @@ def _decompose_exported_program(
673706
graph_signature=new_graph_signature,
674707
state_dict=ep.state_dict,
675708
range_constraints=new_range_constraints,
676-
module_call_graph=copy.deepcopy(ep.module_call_graph),
709+
module_call_graph=new_module_call_graph,
677710
example_inputs=ep.example_inputs,
678711
constants=ep.constants,
679712
)

torch/export/unflatten.py

Lines changed: 119 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from torch.export.exported_program import (
2020
ConstantArgument,
2121
ExportedProgram,
22+
ExportGraphSignature,
2223
InputKind,
2324
ModuleCallSignature,
2425
SymIntArgument,
@@ -219,19 +220,6 @@ def __init__(
219220
if export_module.graph_signature.backward_signature is not None:
220221
raise ValueError("Unflattening on JointExportModule NYI")
221222

222-
preserved_module_targets_with_multiple_calls = [
223-
entry.fqn.split("@")[0]
224-
for entry in export_module.module_call_graph
225-
if "@" in entry.fqn
226-
]
227-
for buf in export_module.graph_signature.buffers_to_mutate.values():
228-
for fqn in preserved_module_targets_with_multiple_calls:
229-
if buf.startswith(fqn + "."):
230-
raise ValueError(
231-
f"Found multiple calls of module {fqn} that mutate buffer {buf}. "
232-
"Unflattening while preserving signatures is NYI for this case."
233-
)
234-
235223
fqn_list = [entry.fqn for entry in export_module.module_call_graph]
236224
assert fqn_list[0] == ""
237225
export_graph = deepcopy(export_module.graph)
@@ -245,7 +233,15 @@ def __init__(
245233
self._run_with_interpeter = RUN_WITH_INTERPRETER
246234

247235
_inplace_buffer_mutations(export_graph, self.graph_signature)
236+
237+
self.ivals = _IVals()
238+
# record any intermediate value x that is used, with the modules that used it,
239+
# and generate instructions to read the corresponding attribute
248240
seen_modules = _outline_submodules(export_graph, self)
241+
# for each read intermediate value x, find the module that created it,
242+
# and generate instructions to update the corresponding attribute;
243+
# finally, initialize all these attributes
244+
self.ivals.create(seen_modules.values())
249245

250246
self.range_constraints = export_module.range_constraints
251247
self.equality_constraints: List = []
@@ -584,7 +580,10 @@ def unflatten(
584580
return UnflattenedModule(module, flat_args_adapter)
585581

586582

587-
def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
583+
def _inplace_buffer_mutations(
584+
graph: torch.fx.Graph,
585+
graph_signature: ExportGraphSignature,
586+
) -> None:
588587
"""Transform buffer mutations from their functionalized form into a copy_
589588
node in the graph.
590589
@@ -784,8 +783,10 @@ def __init__(
784783

785784
if module is not None:
786785
self.module = module
786+
self.ivals = module.ivals if hasattr(module, "ivals") else {}
787787
else:
788788
self.module = InterpreterModule(torch.fx.Graph())
789+
self.ivals = parent.ivals
789790

790791
self.graph = self.module.graph
791792

@@ -948,6 +949,10 @@ def remap_input(self, x):
948949
# if module call signature needs to be preserved
949950
self.copy_sym_call_function(x)
950951
return self.node_map[x]
952+
elif self.module_call_graph.get(self.fqn) is not None:
953+
# x is an ival that is not in placeholders, so create a
954+
# get_attr node corresponding to attribute __ival__x
955+
return self.ivals.read(self.fqn, self.graph, x)
951956
else:
952957
raise RuntimeError(
953958
f"Could not run remap_input() on op type: {x.op} for node {x}"
@@ -1198,6 +1203,106 @@ def _reorder_submodules(
11981203
parent.register_module(name, child)
11991204

12001205

1206+
class _IVals:
1207+
"""
1208+
Collect the intermediate values of buffer mutations in a graph,
1209+
along with the module call fqns that create and use them. Later,
1210+
in each fqn associated with an intermediate value we will install
1211+
a corresponding attribute, so that it can be updated and read.
1212+
1213+
Example: in the following graph, suppose that buf_in and buf_out
1214+
are the input and output values of a buffer.
1215+
1216+
buf_in = placeholder()
1217+
...
1218+
ival1 = f0(buf_in, ...) # inside self.n0(...)
1219+
...
1220+
ival2 = f1(ival1, ...) # inside self.n1(...)
1221+
...
1222+
buf_out = f2(ival2, ...) # inside self.n2(...)
1223+
return buf_out, ...
1224+
1225+
Here ival1 and ival2 are intermediate values created inside
1226+
calls to n0 and n1 respectively, and used inside calls to
1227+
n1 and n2 respectively.
1228+
1229+
Thus our analysis will produce {ival1: {n0, n1}, ival2: {n1, n2}}.
1230+
"""
1231+
1232+
def __init__(self):
1233+
# ival node name -> set of fqns that create and use it
1234+
self.fqns = defaultdict(set)
1235+
# ival node name -> tensor storage for corresponding attribute
1236+
self.storage = {}
1237+
1238+
def read(self, fqn, graph, node):
1239+
"""
1240+
Read attribute corresponding to a given intermediate value.
1241+
"""
1242+
# to read ival x, get attribute __ival__x
1243+
with graph.inserting_before(None):
1244+
ival_node = graph.get_attr("__ival__" + node.name, type_expr=node.type)
1245+
ival_node.meta = copy.copy(node.meta)
1246+
1247+
if node.name not in self.storage:
1248+
# create empty tensor matching fake, using a cache
1249+
# to ensure the same tensor is returned per ival_name
1250+
fake = node.meta["val"]
1251+
self.storage[node.name] = torch.empty(fake.shape, dtype=fake.dtype)
1252+
self.fqns[node.name].add(fqn)
1253+
1254+
return ival_node
1255+
1256+
def update(self, fqn, graph, node):
1257+
"""
1258+
Update attribute corresponding to a given intermediate value.
1259+
"""
1260+
self.fqns[node.name].add(fqn)
1261+
1262+
# to update ival x, get attribute __ival__x and copy x to __ival__x
1263+
with graph.inserting_after(node):
1264+
ival_node = graph.get_attr("__ival__" + node.name, type_expr=node.type)
1265+
ival_node.meta = copy.copy(node.meta)
1266+
with graph.inserting_after(ival_node):
1267+
new_ival_node = graph.create_node(
1268+
"call_function", torch.ops.aten.copy_, (ival_node, node)
1269+
)
1270+
new_ival_node.meta = copy.copy(node.meta)
1271+
1272+
def create(self, partitions):
1273+
"""
1274+
Update attributes corresponding to intermediate values that were read.
1275+
Finally, initialize attributes in all modules that read or update
1276+
corresponding intermediate values.
1277+
"""
1278+
1279+
entries = []
1280+
for shared_submodules in partitions:
1281+
for entry in shared_submodules:
1282+
entries.append(entry)
1283+
graph = entry.module.graph
1284+
for node in graph.nodes:
1285+
if node.name in self.storage:
1286+
self.update(entry.fqn, graph, node)
1287+
1288+
# fqn -> list of ival node names read or updated through it
1289+
ivals = defaultdict(list)
1290+
for name, fqns in self.fqns.items():
1291+
for fqn in fqns:
1292+
ivals[fqn].append(name)
1293+
1294+
for entry in entries:
1295+
for name in ivals[entry.fqn]:
1296+
ival_name = f"__ival__{name}"
1297+
# for a ival named x created in module call m,
1298+
# create attribute m.__ival__x, initially empty
1299+
setattr(
1300+
entry.module,
1301+
ival_name,
1302+
self.storage[name],
1303+
)
1304+
1305+
12011306
def _deduplicate_modules(partitions):
12021307
for shared_submodules in partitions:
12031308
for i, entry in enumerate(shared_submodules):

0 commit comments

Comments (0)