Skip to content

Commit 23fffb5

Browse files
masnesralpytorchmergebot
authored andcommitted
Use OrderedSet in _functorch/partitioners (#146102)
In an attempt to make partitioning more deterministic, change all sets in partitioners.py to OrderedSets. Note that this change does not fix the non-determinism we're seeing in the internal model. But let's at least eliminate this potential source of non-determinism before investigating any changes to the mincut approach? Pull Request resolved: #146102 Approved by: https://github.com/oulgen
1 parent 53759cc commit 23fffb5

File tree

2 files changed

+57
-48
lines changed

2 files changed

+57
-48
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,6 +1703,7 @@ command = [
17031703
]
17041704
include_patterns = [
17051705
"torch/_inductor/**/*.py",
1706+
"torch/_functorch/partitioners.py",
17061707
]
17071708
is_formatter = true
17081709

torch/_functorch/partitioners.py

Lines changed: 56 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
is_symbol_binding_fx_node,
2626
)
2727
from torch.fx.passes import graph_drawer
28+
from torch.utils._ordered_set import OrderedSet
2829
from torch.utils.checkpoint import CheckpointPolicy
2930

3031
from . import config
@@ -55,11 +56,11 @@
5556
class OpTypes:
5657
"""Class for keeping track of different operator categories"""
5758

58-
fusible_ops: set[Callable]
59-
compute_intensive_ops: set[Callable]
60-
random_ops: set[Callable]
61-
view_ops: set[Callable]
62-
recomputable_ops: set[Callable]
59+
fusible_ops: OrderedSet[Callable]
60+
compute_intensive_ops: OrderedSet[Callable]
61+
random_ops: OrderedSet[Callable]
62+
view_ops: OrderedSet[Callable]
63+
recomputable_ops: OrderedSet[Callable]
6364

6465
def is_fusible(self, node: fx.Node):
6566
return get_aten_target(node) in self.fusible_ops
@@ -82,9 +83,9 @@ class NodeInfo:
8283
# Be careful about iterating over these explicitly, as their order may not
8384
# be deterministic
8485
inputs: list[fx.Node]
85-
_required_fw_nodes: set[fx.Node]
86-
required_bw_nodes: set[fx.Node]
87-
unclaimed_nodes: set[fx.Node]
86+
_required_fw_nodes: OrderedSet[fx.Node]
87+
required_bw_nodes: OrderedSet[fx.Node]
88+
unclaimed_nodes: OrderedSet[fx.Node]
8889
fw_order: dict[fx.Node, int]
8990

9091
@functools.cached_property
@@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules(
326327
# we propagate all symbols which are referenced by backwards inputs.
327328
# These are not directly used in the graph but are required for downstream
328329
# sizevar assignment
329-
saved_symbols: set[sympy.Symbol] = set()
330+
saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
330331
saved_sym_nodes_binding = []
331332
saved_sym_nodes_derived = []
332333

@@ -426,9 +427,9 @@ def default_partition(
426427
forward_only_graph = _extract_graph_with_inputs_outputs(
427428
joint_module.graph, inputs, fwd_outputs, "forward"
428429
)
429-
forward_node_names = {
430+
forward_node_names = OrderedSet(
430431
node.name for node in forward_only_graph.nodes if node.op != "output"
431-
}
432+
)
432433
saved_values = []
433434
saved_sym_nodes = []
434435

@@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
580581

581582
def insert_node_in_graph(node):
582583
cur_nodes = [node]
583-
insertable_nodes = set()
584+
insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
584585
while len(cur_nodes) > 0:
585586
node = cur_nodes.pop()
586587
if node in insertable_nodes or node in env:
@@ -817,19 +818,21 @@ def solve_min_cut(
817818
joint_graph: fx.Graph,
818819
node_info: NodeInfo,
819820
min_cut_options: MinCutOptions,
820-
dont_ban=None,
821+
dont_ban: Optional[OrderedSet[fx.Node]] = None,
821822
):
822823
if dont_ban is None:
823-
dont_ban = set()
824+
dont_ban = OrderedSet()
824825
op_types = get_default_op_list()
825826

826827
if AOT_PARTITIONER_DEBUG:
827-
joint_module_ops = {
828+
joint_module_ops = OrderedSet(
828829
str(node.target._overloadpacket)
829830
for node in joint_graph.nodes
830831
if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
831-
}
832-
ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops}
832+
)
833+
ops_ignored = joint_module_ops - OrderedSet(
834+
str(i) for i in op_types.recomputable_ops
835+
)
833836
log.info("Ops banned from re-materialization: %s", ops_ignored)
834837

835838
def can_fuse_into_auto_functionalized(a, b):
@@ -888,7 +891,7 @@ def is_fusible(a, b):
888891
def is_materialized_backwards(node):
889892
if op_types.is_view(node):
890893
return False
891-
cur_nodes = {node}
894+
cur_nodes = OrderedSet([node])
892895
while len(cur_nodes) > 0:
893896
cur = cur_nodes.pop()
894897
for user in cur.users:
@@ -981,7 +984,7 @@ def get_node_weight(node) -> float:
981984
return mem_sz * 2
982985

983986
nx_graph = nx.DiGraph()
984-
banned_nodes = set()
987+
banned_nodes: OrderedSet[fx.Node] = OrderedSet()
985988

986989
def ban_recomputation_if_allowed(node):
987990
if op_types.is_view(node):
@@ -1091,12 +1094,13 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
10911094
if node_info.is_required_fw(user):
10921095
if node_info.get_fw_order(user) > max_range:
10931096
continue
1094-
val = (node_info.get_fw_order(user), user, is_fusible(node, user))
1097+
val: tuple[int, fx.Node, bool] = (
1098+
node_info.get_fw_order(user),
1099+
user,
1100+
is_fusible(node, user),
1101+
)
10951102
if val not in sorted_nodes:
1096-
heapq.heappush(
1097-
sorted_nodes,
1098-
val,
1099-
)
1103+
heapq.heappush(sorted_nodes, val)
11001104
return max_range
11011105

11021106
if min_cut_options.ban_if_used_far_apart:
@@ -1141,11 +1145,13 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
11411145
# Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
11421146

11431147
if min_cut_options.ban_if_long_fusible_chains:
1144-
visited = set()
1148+
visited: OrderedSet[fx.Node] = OrderedSet()
11451149
for start_node in joint_graph.nodes:
11461150
if not node_info.is_required_fw(start_node):
11471151
continue
1148-
fusible = [(node_info.get_fw_order(start_node), start_node)]
1152+
fusible: list[tuple[int, fx.Node]] = [
1153+
(node_info.get_fw_order(start_node), start_node)
1154+
]
11491155
start_order = node_info.get_fw_order(start_node)
11501156
while len(fusible) > 0:
11511157
_, cur = heapq.heappop(fusible)
@@ -1184,11 +1190,11 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
11841190
raise
11851191

11861192
reachable, non_reachable = partition
1187-
cutset: set[tuple[str, str]] = set()
1193+
cutset: OrderedSet[tuple[str, str]] = OrderedSet()
11881194
for u, nbrs in ((n, nx_graph[n]) for n in reachable):
11891195
cutset.update((u, v) for v in nbrs if v in non_reachable)
11901196

1191-
cut_nodes = set()
1197+
cut_nodes: OrderedSet[str] = OrderedSet()
11921198
for node_in, node_out in cutset:
11931199
assert node_in[:-3] == node_out[:-4]
11941200
node_name = node_in[:-3]
@@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes:
13581364
]
13591365

13601366
default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
1361-
recomputable_ops = set(default_recomputable_ops)
1367+
recomputable_ops = OrderedSet(default_recomputable_ops)
13621368

1363-
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
1369+
random_ops = OrderedSet([aten.native_dropout, aten.rand_like, aten.randn_like])
13641370
compute_intensive_ops = [
13651371
aten.mm,
13661372
aten.convolution,
@@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes:
13751381
aten._scaled_mm,
13761382
] # noqa: E501,B950
13771383

1378-
fusible_ops = recomputable_ops | set(random_ops)
1384+
fusible_ops = recomputable_ops | random_ops
13791385
return OpTypes(
1380-
set(fusible_ops),
1381-
set(compute_intensive_ops),
1382-
set(random_ops),
1383-
set(view_ops),
1384-
set(recomputable_ops),
1386+
fusible_ops,
1387+
OrderedSet(compute_intensive_ops),
1388+
random_ops,
1389+
OrderedSet(view_ops),
1390+
recomputable_ops,
13851391
)
13861392

13871393

@@ -1567,9 +1573,11 @@ def get_mem_ratio(activations: list[fx.Node]):
15671573

15681574
from torch._inductor.fx_utils import get_node_storage
15691575

1570-
input_storages = {get_node_storage(node) for node in node_info.inputs}
1576+
input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)
15711577

1572-
def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
1578+
def get_recomputable_banned_nodes(
1579+
banned_nodes: OrderedSet[fx.Node],
1580+
) -> list[fx.Node]:
15731581
return [
15741582
i
15751583
for i in banned_nodes
@@ -1653,7 +1661,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
16531661
payload_fn=lambda: knapsack_summary,
16541662
)
16551663
log.info(knapsack_summary)
1656-
dont_ban = set()
1664+
dont_ban: OrderedSet[fx.Node] = OrderedSet()
16571665
for idx in recomputable_node_idxs:
16581666
# if idx in all_recomputable_banned_nodes:
16591667
try:
@@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition(
17761784

17771785
def classify_nodes(joint_module):
17781786
name_to_node = get_name_to_node(joint_module.graph)
1779-
required_bw_nodes = set()
1787+
required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
17801788
for node in joint_module.graph.nodes:
17811789
if node.op == "placeholder" and "tangents" in node.target:
17821790
required_bw_nodes.add(node)
@@ -1800,16 +1808,16 @@ def classify_nodes(joint_module):
18001808
forward_only_graph = _extract_graph_with_inputs_outputs(
18011809
joint_module.graph, inputs, fwd_outputs, "forward"
18021810
)
1803-
required_fw_nodes: set[fx.Node] = {
1811+
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
18041812
name_to_node[node.name]
18051813
for node in forward_only_graph.nodes
18061814
if node.op != "output"
1807-
}
1808-
unclaimed_nodes = {
1815+
)
1816+
unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
18091817
node
18101818
for node in joint_module.graph.nodes
18111819
if node not in required_fw_nodes and node not in required_bw_nodes
1812-
}
1820+
)
18131821
fw_cnt = 0
18141822
fw_order = {}
18151823
for node in joint_module.graph.nodes:
@@ -1879,12 +1887,12 @@ def classify_nodes(joint_module):
18791887

18801888
# Log theoretical per activation storage sizes
18811889
log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
1882-
fw_module_nodes = {
1890+
fw_module_nodes = OrderedSet(
18831891
node.name for node in fw_module.graph.nodes if node.op == "call_function"
1884-
}
1885-
bw_module_nodes = {
1892+
)
1893+
bw_module_nodes = OrderedSet(
18861894
node.name for node in bw_module.graph.nodes if node.op == "call_function"
1887-
}
1895+
)
18881896
remat_nodes = fw_module_nodes & bw_module_nodes
18891897

18901898
counts: dict[str, int] = defaultdict(int)

0 commit comments

Comments
 (0)