
Commit fe007e4

angelayi authored and AnantGulati committed
[ts_converter] Fix prim::If buffer names (#136648)
Summary:
We previously handled the following graph incorrectly, specifically the node `%w.3` in `block0`:

```
graph(%x.1 : Float(3, strides=[1], requires_grad=0, device=cpu),
      %y.1 : int):
  %2 : __torch__.___torch_mangle_1.M = prim::CreateObject()
  %3 : int = prim::Constant[value=20](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:34
  %4 : int = prim::Constant[value=10](), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:34
  %5 : int = prim::Constant[value=1](), scope: M::
  %w.1 : int = prim::GetAttr[name="w"](%2), scope: M::
  %7 : int = aten::mul(%w.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:746:25
   = prim::SetAttr[name="w"](%2, %7), scope: M::
  %h.1 : int = prim::GetAttr[name="h"](%2), scope: M::
  %9 : int = aten::mul(%h.1, %3), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:747:25
   = prim::SetAttr[name="h"](%2, %9), scope: M::
  %10 : bool = aten::gt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:19
  %res.37 : Tensor = prim::If(%10), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:749:16
    block0():
      %w.3 : int = prim::GetAttr[name="w"](%2), scope: M::
      %res.1 : Tensor = aten::add(%x.1, %w.3, %5), scope: M:: # <string>:5:9
      -> (%res.1)
    block1():
      %h.3 : int = prim::GetAttr[name="h"](%2), scope: M::
      %res.3 : Tensor = aten::add(%x.1, %h.3, %5), scope: M:: # <string>:5:9
      -> (%res.3)
  %16 : bool = aten::lt(%y.1, %4), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:19
  %res : Tensor = prim::If(%16), scope: M:: # /data/users/angelayi/pytorch/test/export/test_converter.py:754:16
    block0():
      %w : int = prim::GetAttr[name="w"](%2), scope: M::
      %res.15 : Tensor = aten::add(%res.37, %w, %5), scope: M:: # <string>:5:9
      -> (%res.15)
    block1():
      %h : int = prim::GetAttr[name="h"](%2), scope: M::
      %res.21 : Tensor = aten::add(%res.37, %h, %5), scope: M:: # <string>:5:9
      -> (%res.21)
  return (%res)
```

Test Plan: CI

Differential Revision: D63399064

Pull Request resolved: #136648
Approved by: https://github.com/SherlockNoMad
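The crux, visible in the dump above: every `prim::GetAttr` output is a fresh SSA value (`%w.1`, `%w.3`, `%w`), yet all of them alias the single attribute FQN `w`. Resolution therefore has to go through a node-name to attribute-FQN table rather than treating each output name as an independent block input. A minimal self-contained sketch of that lookup, using plain dicts as illustrative stand-ins for the converter's state:

```python
# Illustrative stand-ins, not the converter's real internals.
# Every GetAttr output name aliases the FQN of the attribute it reads.
node_to_attr_fqn = {"w.1": "w", "w.3": "w", "w": "w",
                    "h.1": "h", "h.3": "h", "h": "h"}
attr_fqn_to_value = {"w": 10, "h": 40}  # values after the SetAttr mutations

def resolve(value_name: str) -> int:
    # The fallback this PR adds: map the SSA name to its attribute FQN first.
    if value_name in node_to_attr_fqn:
        return attr_fqn_to_value[node_to_attr_fqn[value_name]]
    raise ValueError(f"Input {value_name} not found")

assert resolve("w.3") == resolve("w") == 10  # both names alias attribute "w"
```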
1 parent b18a850 commit fe007e4

File tree: 2 files changed, +54 -7 lines changed

test/export/test_converter.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -838,6 +838,32 @@ def forward(self, x: torch.Tensor):
             orig_m(*inp),
         )
 
+    def test_convert_if_duplicate_attr_names(self):
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.w = 1
+                self.h = 2
+
+            def forward(self, x: torch.Tensor, y: int):
+                self.w = self.w * 10
+                self.h = self.h * 20
+
+                if y > 10:
+                    res = self.w + x
+                else:
+                    res = self.h + x
+
+                if y < 10:
+                    res = self.w + res
+                else:
+                    res = self.h + res
+
+                return res
+
+        inp = (torch.ones(3), 5)
+        self._check_equal_ts_ep_converter(M(), inp, option=["script"])
+
     def test_ts2ep_converter_contains(self):
         class MIn(torch.nn.Module):
             def forward(self, x: torch.Tensor):
```
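As a sanity check on what the new test asserts, here is a hand-worked trace of `M()(torch.ones(3), 5)` in plain Python (a sketch of the expected semantics, not converter output). With `y = 5`, the first `if` takes the `else` branch and reads `h`, the second takes the `then` branch and reads `w`, so the two attributes are read from different `prim::If` blocks, which is exactly the duplicate-name situation from the summary:

```python
import torch

# Mirror the test module's forward() with local variables.
w, h = 1, 2
x, y = torch.ones(3), 5
w, h = w * 10, h * 20                 # w = 10, h = 40
res = w + x if y > 10 else h + x      # y > 10 is False -> res = 40 + 1 = 41
res = w + res if y < 10 else h + res  # y < 10 is True  -> res = 10 + 41 = 51
assert torch.equal(res, torch.full((3,), 51.0))
```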

torch/_export/converter.py

Lines changed: 28 additions & 7 deletions
```diff
@@ -260,7 +260,9 @@ def construct_fqn(ir, ref_map, name_map):
     return ".".join(reversed(name_list))
 
 
-def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
+def get_block_to_lifted_attrs(
+    graph: torch._C.Graph,
+) -> Tuple[Dict[torch._C.Block, Set[str]], Dict[str, str]]:
     """
     Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
     When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
@@ -272,7 +274,8 @@ def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set
     of the attributes used in the current block, and the lifted attributes of all its child blocks.
 
     Returns:
-        A mapping of blocks to a set of FQNs of its lifted attributes.
+        A mapping of blocks to a set of FQNs of its lifted attributes, and a
+        mapping of node names to the FQNs of its lifted attributes.
     """
 
     # A map from a block to its expected to be lifted arguments.
@@ -334,7 +337,7 @@ def _map_blocks_to_lifted_attrs(entry):
     _dfs_get_attr_dependency(graph)
     _map_blocks_to_lifted_attrs(graph)
 
-    return blocks_to_lifted_attrs
+    return blocks_to_lifted_attrs, node_to_attr_name
 
 
 def get_attribute_fqn_from_ts_node(
@@ -393,22 +396,28 @@ def __init__(
         blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
         name_to_non_tensor_attribute: Dict[str, Any],
         name_to_constant: Dict[str, Any],
+        name_to_attribute_fqn: Dict[str, str],
     ):
         self.ts_graph = ts_graph
+        # Mapping of parameter FQN to actual parameter value
         self.name_to_param = name_to_param
+        # Mapping of buffer FQN to actual buffer value
         self.name_to_buffer = name_to_buffer
 
         self.fx_graph: torch.fx.Graph = torch.fx.Graph()
         self.input_specs: List[InputSpec] = []
         self.output_specs: List[OutputSpec] = []
 
+        # Mapping of TS node name to converted FX node
         self.name_to_node: Dict[
             str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
         ] = {}
+        # Mapping of TS node name to constant value (int, str, TorchBind obj,
+        # tensor constants ...)
        self.name_to_constant: Dict[str, Any] = name_to_constant
 
         # Mapping from torchscript node output name to attribute fully qualified name
-        self.name_to_attribute_fqn: Dict[str, str] = {}
+        self.name_to_attribute_fqn: Dict[str, str] = name_to_attribute_fqn
 
         # Mapping from fully qualified name to real values or a fx graph node
         # During convert, this represents the current value of a non-tensor attribute
@@ -427,6 +436,8 @@ def __init__(
 
         self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
 
+        # Mapping of block to list of attributes that need to be lifted for each
+        # block
         self.blocks_to_lifted_attrs = blocks_to_lifted_attrs
 
         # Populate methods for the standard operators.
@@ -467,8 +478,8 @@ def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
             self.blocks_to_lifted_attrs,
             {},
             self.name_to_constant,
+            self.name_to_attribute_fqn,
         )
-        subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn
 
         for block_arg in arguments:
             normalized_block_arg_name = normalize_name(block_arg)
@@ -537,6 +548,8 @@ def get_fx_value_by_ir_value(self, value: torch._C.Value):
             if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
                 return self.fx_graph.get_attr(value_name)
             return self.name_to_constant[value_name]
+        elif value_name in self.name_to_attribute_fqn:
+            return self.get_fx_value_by_fqn(self.name_to_attribute_fqn[value_name])
         else:
             raise ValueError(f"Input {value_name} not found")
 
@@ -1325,6 +1338,7 @@ def __init__(
         blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
         name_to_non_tensor_attribute: Dict[str, Any],
         name_to_constant: Dict[str, Any],
+        name_to_attribute_fqn: Dict[str, str],
     ):
         super().__init__(
             ts_graph,
@@ -1333,6 +1347,7 @@ def __init__(
             blocks_to_lifted_attrs,
             name_to_non_tensor_attribute,
             name_to_constant,
+            name_to_attribute_fqn,
         )
 
         # Data to keep track of unsupported nodes.
@@ -1427,7 +1442,9 @@ def convert(self) -> ExportedProgram:
         )
         log.info("TorchScript graph\n\n%s\n", self.ts_graph)
 
-        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
+        blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs(
+            self.ts_graph
+        )
 
         graph_converter = TS2FXGraphConverter(
             self.ts_graph,
@@ -1436,6 +1453,7 @@ def convert(self) -> ExportedProgram:
             blocks_to_lifted_attrs,
             self.name_to_non_tensor_attributes,
             self.name_to_constant,
+            name_to_attribute_fqn,
         )
         gm = graph_converter.convert()
 
@@ -1464,7 +1482,9 @@ def convert(self) -> ExportedProgram:
 
     @disable_logging(log)
     def explain(self, print_output=True):
-        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
+        blocks_to_lifted_attrs, name_to_attribute_fqn = get_block_to_lifted_attrs(
+            self.ts_graph
+        )
 
         graph_converter = ExplainTS2FXGraphConverter(
             self.ts_graph,
@@ -1473,6 +1493,7 @@ def explain(self, print_output=True):
             blocks_to_lifted_attrs,
             self.name_to_non_tensor_attributes,
             self.name_to_constant,
+            name_to_attribute_fqn,
         )
         graph_converter.explain()
         if len(graph_converter.unsupported_node_list) > 0:
```
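Putting the hunks together: `convert()` and `explain()` now compute both maps once per TorchScript graph and seed the top-level `TS2FXGraphConverter` (and, through `_convert_block_to_subgraph`, every nested block converter) with the shared `name_to_attribute_fqn`. Below is a rough end-to-end sketch of driving the converter on the repro module; hedged, since the `TS2EPConverter` entry point and its `(ts_model, sample_args)` constructor are assumed from the test suite's `_check_equal_ts_ep_converter` helper rather than verified against this exact revision:

```python
import torch
from torch._export.converter import TS2EPConverter  # assumed entry point

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = 1
        self.h = 2

    def forward(self, x: torch.Tensor, y: int):
        self.w = self.w * 10
        self.h = self.h * 20
        # Two prim::If nodes reading the same attributes produce the
        # duplicate GetAttr output names (%w.3 vs. %w) from the summary.
        res = self.w + x if y > 10 else self.h + x
        return self.w + res if y < 10 else self.h + res

inp = (torch.ones(3), 5)
ts_model = torch.jit.script(M())
ep = TS2EPConverter(ts_model, inp).convert()  # the case this PR fixes
torch.testing.assert_close(ep.module()(*inp), ts_model(*inp))
```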

0 commit comments