1919from torch .export .exported_program import (
2020 ConstantArgument ,
2121 ExportedProgram ,
22+ ExportGraphSignature ,
2223 InputKind ,
2324 ModuleCallSignature ,
2425 SymIntArgument ,
@@ -219,19 +220,6 @@ def __init__(
219220 if export_module .graph_signature .backward_signature is not None :
220221 raise ValueError ("Unflattening on JointExportModule NYI" )
221222
222- preserved_module_targets_with_multiple_calls = [
223- entry .fqn .split ("@" )[0 ]
224- for entry in export_module .module_call_graph
225- if "@" in entry .fqn
226- ]
227- for buf in export_module .graph_signature .buffers_to_mutate .values ():
228- for fqn in preserved_module_targets_with_multiple_calls :
229- if buf .startswith (fqn + "." ):
230- raise ValueError (
231- f"Found multiple calls of module { fqn } that mutate buffer { buf } . "
232- "Unflattening while preserving signatures is NYI for this case."
233- )
234-
235223 fqn_list = [entry .fqn for entry in export_module .module_call_graph ]
236224 assert fqn_list [0 ] == ""
237225 export_graph = deepcopy (export_module .graph )
@@ -245,7 +233,15 @@ def __init__(
245233 self ._run_with_interpeter = RUN_WITH_INTERPRETER
246234
247235 _inplace_buffer_mutations (export_graph , self .graph_signature )
236+
237+ self .ivals = _IVals ()
238+ # record any intermediate value x that is used, with the modules that used it,
239+ # and generate instructions to read the corresponding attribute
248240 seen_modules = _outline_submodules (export_graph , self )
241+ # for each read intermediate value x, find the module that created it,
242+ # and generate instructions to update the corresponding attribute;
243+ # finally, initialize all these attributes
244+ self .ivals .create (seen_modules .values ())
249245
250246 self .range_constraints = export_module .range_constraints
251247 self .equality_constraints : List = []
@@ -584,7 +580,10 @@ def unflatten(
584580 return UnflattenedModule (module , flat_args_adapter )
585581
586582
587- def _inplace_buffer_mutations (graph : torch .fx .Graph , graph_signature ) -> None :
583+ def _inplace_buffer_mutations (
584+ graph : torch .fx .Graph ,
585+ graph_signature : ExportGraphSignature ,
586+ ) -> None :
588587 """Transform buffer mutations from their functionalized form into a copy_
589588 node in the graph.
590589
@@ -784,8 +783,10 @@ def __init__(
784783
785784 if module is not None :
786785 self .module = module
786+ self .ivals = module .ivals if hasattr (module , "ivals" ) else {}
787787 else :
788788 self .module = InterpreterModule (torch .fx .Graph ())
789+ self .ivals = parent .ivals
789790
790791 self .graph = self .module .graph
791792
@@ -948,6 +949,10 @@ def remap_input(self, x):
948949 # if module call signature needs to be preserved
949950 self .copy_sym_call_function (x )
950951 return self .node_map [x ]
952+ elif self .module_call_graph .get (self .fqn ) is not None :
953+ # x is an ival that is not in placeholders, so create a
954+ # get_attr node corresponding to attribute __ival__x
955+ return self .ivals .read (self .fqn , self .graph , x )
951956 else :
952957 raise RuntimeError (
953958 f"Could not run remap_input() on op type: { x .op } for node { x } "
@@ -1198,6 +1203,106 @@ def _reorder_submodules(
11981203 parent .register_module (name , child )
11991204
12001205
1206+ class _IVals :
1207+ """
1208+ Collect the intermediate values of buffer mutations in a graph,
1209+ along with the module call fqns that create and use them. Later,
1210+ in each fqn associated with an intermediate value we will install
1211+ a corresponding attribute, so that it can be updated and read.
1212+
1213+ Example: in the following graph, suppose that buf_in and buf_out
1214+ are the input and output values of a buffer.
1215+
1216+ buf_in = placeholder()
1217+ ...
1218+ ival1 = f0(buf_in, ...) # inside self.n0(...)
1219+ ...
1220+ ival2 = f1(ival1, ...) # inside self.n1(...)
1221+ ...
1222+ buf_out = f2(ival2, ...) # inside self.n2(...)
1223+ return buf_out, ...
1224+
1225+ Here ival1 and ival2 are intermediate values created inside
1226+ calls to n0 and n1 respectively, and used inside calls to
1227+ n1 and n2 respectively.
1228+
1229+ Thus our analysis will produce {ival1: {n0, n1}, ival2: {n1, n2}}.
1230+ """
1231+
1232+ def __init__ (self ):
1233+ # ival node name -> set of fqns that create and use it
1234+ self .fqns = defaultdict (set )
1235+ # ival node name -> tensor storage for corresponding attribute
1236+ self .storage = {}
1237+
1238+ def read (self , fqn , graph , node ):
1239+ """
1240+ Read attribute corresponding to a given intermediate value.
1241+ """
1242+ # to read ival x, get attribute __ival__x
1243+ with graph .inserting_before (None ):
1244+ ival_node = graph .get_attr ("__ival__" + node .name , type_expr = node .type )
1245+ ival_node .meta = copy .copy (node .meta )
1246+
1247+ if node .name not in self .storage :
1248+ # create empty tensor matching fake, using a cache
1249+ # to ensure the same tensor is returned per ival_name
1250+ fake = node .meta ["val" ]
1251+ self .storage [node .name ] = torch .empty (fake .shape , dtype = fake .dtype )
1252+ self .fqns [node .name ].add (fqn )
1253+
1254+ return ival_node
1255+
1256+ def update (self , fqn , graph , node ):
1257+ """
1258+ Update attribute corresponding to a given intermediate value.
1259+ """
1260+ self .fqns [node .name ].add (fqn )
1261+
1262+ # to update ival x, get attribute __ival__x and copy x to __ival__x
1263+ with graph .inserting_after (node ):
1264+ ival_node = graph .get_attr ("__ival__" + node .name , type_expr = node .type )
1265+ ival_node .meta = copy .copy (node .meta )
1266+ with graph .inserting_after (ival_node ):
1267+ new_ival_node = graph .create_node (
1268+ "call_function" , torch .ops .aten .copy_ , (ival_node , node )
1269+ )
1270+ new_ival_node .meta = copy .copy (node .meta )
1271+
1272+ def create (self , partitions ):
1273+ """
1274+ Update attributes corresponding to intermediate values that were read.
1275+ Finally, initialize attributes in all modules that read or update
1276+ corresponding intermediate values.
1277+ """
1278+
1279+ entries = []
1280+ for shared_submodules in partitions :
1281+ for entry in shared_submodules :
1282+ entries .append (entry )
1283+ graph = entry .module .graph
1284+ for node in graph .nodes :
1285+ if node .name in self .storage :
1286+ self .update (entry .fqn , graph , node )
1287+
1288+ # fqn -> list of ival node names read or updated through it
1289+ ivals = defaultdict (list )
1290+ for name , fqns in self .fqns .items ():
1291+ for fqn in fqns :
1292+ ivals [fqn ].append (name )
1293+
1294+ for entry in entries :
1295+ for name in ivals [entry .fqn ]:
1296+ ival_name = f"__ival__{ name } "
1297+ # for a ival named x created in module call m,
1298+ # create attribute m.__ival__x, initially empty
1299+ setattr (
1300+ entry .module ,
1301+ ival_name ,
1302+ self .storage [name ],
1303+ )
1304+
1305+
12011306def _deduplicate_modules (partitions ):
12021307 for shared_submodules in partitions :
12031308 for i , entry in enumerate (shared_submodules ):
0 commit comments