2525 is_symbol_binding_fx_node ,
2626)
2727from torch .fx .passes import graph_drawer
28+ from torch .utils ._ordered_set import OrderedSet
2829from torch .utils .checkpoint import CheckpointPolicy
2930
3031from . import config
5556class OpTypes :
5657 """Class for keeping track of different operator categories"""
5758
58- fusible_ops : set [Callable ]
59- compute_intensive_ops : set [Callable ]
60- random_ops : set [Callable ]
61- view_ops : set [Callable ]
62- recomputable_ops : set [Callable ]
59+ fusible_ops : OrderedSet [Callable ]
60+ compute_intensive_ops : OrderedSet [Callable ]
61+ random_ops : OrderedSet [Callable ]
62+ view_ops : OrderedSet [Callable ]
63+ recomputable_ops : OrderedSet [Callable ]
6364
6465 def is_fusible (self , node : fx .Node ):
6566 return get_aten_target (node ) in self .fusible_ops
@@ -82,9 +83,9 @@ class NodeInfo:
8283 # Be careful about iterating over these explicitly, as their order may not
8384 # be deterministic
8485 inputs : list [fx .Node ]
85- _required_fw_nodes : set [fx .Node ]
86- required_bw_nodes : set [fx .Node ]
87- unclaimed_nodes : set [fx .Node ]
86+ _required_fw_nodes : OrderedSet [fx .Node ]
87+ required_bw_nodes : OrderedSet [fx .Node ]
88+ unclaimed_nodes : OrderedSet [fx .Node ]
8889 fw_order : dict [fx .Node , int ]
8990
9091 @functools .cached_property
@@ -326,7 +327,7 @@ def _extract_fwd_bwd_modules(
326327 # we propagate all symbols which are referenced by backwards inputs.
327328 # These are not directly used in the graph but are required for downstream
328329 # sizevar assignment
329- saved_symbols : set [sympy .Symbol ] = set ()
330+ saved_symbols : OrderedSet [sympy .Symbol ] = OrderedSet ()
330331 saved_sym_nodes_binding = []
331332 saved_sym_nodes_derived = []
332333
@@ -426,9 +427,9 @@ def default_partition(
426427 forward_only_graph = _extract_graph_with_inputs_outputs (
427428 joint_module .graph , inputs , fwd_outputs , "forward"
428429 )
429- forward_node_names = {
430+ forward_node_names = OrderedSet (
430431 node .name for node in forward_only_graph .nodes if node .op != "output"
431- }
432+ )
432433 saved_values = []
433434 saved_sym_nodes = []
434435
@@ -580,7 +581,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
580581
581582 def insert_node_in_graph (node ):
582583 cur_nodes = [node ]
583- insertable_nodes = set ()
584+ insertable_nodes : OrderedSet [ fx . Node ] = OrderedSet ()
584585 while len (cur_nodes ) > 0 :
585586 node = cur_nodes .pop ()
586587 if node in insertable_nodes or node in env :
@@ -817,19 +818,21 @@ def solve_min_cut(
817818 joint_graph : fx .Graph ,
818819 node_info : NodeInfo ,
819820 min_cut_options : MinCutOptions ,
820- dont_ban = None ,
821+ dont_ban : Optional [ OrderedSet [ fx . Node ]] = None ,
821822):
822823 if dont_ban is None :
823- dont_ban = set ()
824+ dont_ban = OrderedSet ()
824825 op_types = get_default_op_list ()
825826
826827 if AOT_PARTITIONER_DEBUG :
827- joint_module_ops = {
828+ joint_module_ops = OrderedSet (
828829 str (node .target ._overloadpacket )
829830 for node in joint_graph .nodes
830831 if node .op == "call_function" and hasattr (node .target , "_overloadpacket" )
831- }
832- ops_ignored = joint_module_ops - {str (i ) for i in op_types .recomputable_ops }
832+ )
833+ ops_ignored = joint_module_ops - OrderedSet (
834+ str (i ) for i in op_types .recomputable_ops
835+ )
833836 log .info ("Ops banned from re-materialization: %s" , ops_ignored )
834837
835838 def can_fuse_into_auto_functionalized (a , b ):
@@ -888,7 +891,7 @@ def is_fusible(a, b):
888891 def is_materialized_backwards (node ):
889892 if op_types .is_view (node ):
890893 return False
891- cur_nodes = { node }
894+ cur_nodes = OrderedSet ([ node ])
892895 while len (cur_nodes ) > 0 :
893896 cur = cur_nodes .pop ()
894897 for user in cur .users :
@@ -981,7 +984,7 @@ def get_node_weight(node) -> float:
981984 return mem_sz * 2
982985
983986 nx_graph = nx .DiGraph ()
984- banned_nodes = set ()
987+ banned_nodes : OrderedSet [ fx . Node ] = OrderedSet ()
985988
986989 def ban_recomputation_if_allowed (node ):
987990 if op_types .is_view (node ):
@@ -1091,12 +1094,13 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
10911094 if node_info .is_required_fw (user ):
10921095 if node_info .get_fw_order (user ) > max_range :
10931096 continue
1094- val = (node_info .get_fw_order (user ), user , is_fusible (node , user ))
1097+ val : tuple [int , fx .Node , bool ] = (
1098+ node_info .get_fw_order (user ),
1099+ user ,
1100+ is_fusible (node , user ),
1101+ )
10951102 if val not in sorted_nodes :
1096- heapq .heappush (
1097- sorted_nodes ,
1098- val ,
1099- )
1103+ heapq .heappush (sorted_nodes , val )
11001104 return max_range
11011105
11021106 if min_cut_options .ban_if_used_far_apart :
@@ -1141,11 +1145,13 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
11411145 # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
11421146
11431147 if min_cut_options .ban_if_long_fusible_chains :
1144- visited = set ()
1148+ visited : OrderedSet [ fx . Node ] = OrderedSet ()
11451149 for start_node in joint_graph .nodes :
11461150 if not node_info .is_required_fw (start_node ):
11471151 continue
1148- fusible = [(node_info .get_fw_order (start_node ), start_node )]
1152+ fusible : list [tuple [int , fx .Node ]] = [
1153+ (node_info .get_fw_order (start_node ), start_node )
1154+ ]
11491155 start_order = node_info .get_fw_order (start_node )
11501156 while len (fusible ) > 0 :
11511157 _ , cur = heapq .heappop (fusible )
@@ -1184,11 +1190,11 @@ def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
11841190 raise
11851191
11861192 reachable , non_reachable = partition
1187- cutset : set [tuple [str , str ]] = set ()
1193+ cutset : OrderedSet [tuple [str , str ]] = OrderedSet ()
11881194 for u , nbrs in ((n , nx_graph [n ]) for n in reachable ):
11891195 cutset .update ((u , v ) for v in nbrs if v in non_reachable )
11901196
1191- cut_nodes = set ()
1197+ cut_nodes : OrderedSet [ str ] = OrderedSet ()
11921198 for node_in , node_out in cutset :
11931199 assert node_in [:- 3 ] == node_out [:- 4 ]
11941200 node_name = node_in [:- 3 ]
@@ -1358,9 +1364,9 @@ def get_default_op_list() -> OpTypes:
13581364 ]
13591365
13601366 default_recomputable_ops += [method_to_operator (m ) for m in magic_methods ]
1361- recomputable_ops = set (default_recomputable_ops )
1367+ recomputable_ops = OrderedSet (default_recomputable_ops )
13621368
1363- random_ops = [aten .native_dropout , aten .rand_like , aten .randn_like ]
1369+ random_ops = OrderedSet ( [aten .native_dropout , aten .rand_like , aten .randn_like ])
13641370 compute_intensive_ops = [
13651371 aten .mm ,
13661372 aten .convolution ,
@@ -1375,13 +1381,13 @@ def get_default_op_list() -> OpTypes:
13751381 aten ._scaled_mm ,
13761382 ] # noqa: E501,B950
13771383
1378- fusible_ops = recomputable_ops | set ( random_ops )
1384+ fusible_ops = recomputable_ops | random_ops
13791385 return OpTypes (
1380- set ( fusible_ops ) ,
1381- set (compute_intensive_ops ),
1382- set ( random_ops ) ,
1383- set (view_ops ),
1384- set ( recomputable_ops ) ,
1386+ fusible_ops ,
1387+ OrderedSet (compute_intensive_ops ),
1388+ random_ops ,
1389+ OrderedSet (view_ops ),
1390+ recomputable_ops ,
13851391 )
13861392
13871393
@@ -1567,9 +1573,11 @@ def get_mem_ratio(activations: list[fx.Node]):
15671573
15681574 from torch ._inductor .fx_utils import get_node_storage
15691575
1570- input_storages = { get_node_storage (node ) for node in node_info .inputs }
1576+ input_storages = OrderedSet ( get_node_storage (node ) for node in node_info .inputs )
15711577
1572- def get_recomputable_banned_nodes (banned_nodes : set [fx .Node ]) -> list [fx .Node ]:
1578+ def get_recomputable_banned_nodes (
1579+ banned_nodes : OrderedSet [fx .Node ],
1580+ ) -> list [fx .Node ]:
15731581 return [
15741582 i
15751583 for i in banned_nodes
@@ -1653,7 +1661,7 @@ def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
16531661 payload_fn = lambda : knapsack_summary ,
16541662 )
16551663 log .info (knapsack_summary )
1656- dont_ban = set ()
1664+ dont_ban : OrderedSet [ fx . Node ] = OrderedSet ()
16571665 for idx in recomputable_node_idxs :
16581666 # if idx in all_recomputable_banned_nodes:
16591667 try :
@@ -1776,7 +1784,7 @@ def min_cut_rematerialization_partition(
17761784
17771785 def classify_nodes (joint_module ):
17781786 name_to_node = get_name_to_node (joint_module .graph )
1779- required_bw_nodes = set ()
1787+ required_bw_nodes : OrderedSet [ fx . Node ] = OrderedSet ()
17801788 for node in joint_module .graph .nodes :
17811789 if node .op == "placeholder" and "tangents" in node .target :
17821790 required_bw_nodes .add (node )
@@ -1800,16 +1808,16 @@ def classify_nodes(joint_module):
18001808 forward_only_graph = _extract_graph_with_inputs_outputs (
18011809 joint_module .graph , inputs , fwd_outputs , "forward"
18021810 )
1803- required_fw_nodes : set [fx .Node ] = {
1811+ required_fw_nodes : OrderedSet [fx .Node ] = OrderedSet (
18041812 name_to_node [node .name ]
18051813 for node in forward_only_graph .nodes
18061814 if node .op != "output"
1807- }
1808- unclaimed_nodes = {
1815+ )
1816+ unclaimed_nodes : OrderedSet [ fx . Node ] = OrderedSet (
18091817 node
18101818 for node in joint_module .graph .nodes
18111819 if node not in required_fw_nodes and node not in required_bw_nodes
1812- }
1820+ )
18131821 fw_cnt = 0
18141822 fw_order = {}
18151823 for node in joint_module .graph .nodes :
@@ -1879,12 +1887,12 @@ def classify_nodes(joint_module):
18791887
18801888 # Log theoretical per activation storage sizes
18811889 log .info ("Theoretical Per Activation Storage Sizes: %s" , sorted_sizes )
1882- fw_module_nodes = {
1890+ fw_module_nodes = OrderedSet (
18831891 node .name for node in fw_module .graph .nodes if node .op == "call_function"
1884- }
1885- bw_module_nodes = {
1892+ )
1893+ bw_module_nodes = OrderedSet (
18861894 node .name for node in bw_module .graph .nodes if node .op == "call_function"
1887- }
1895+ )
18881896 remat_nodes = fw_module_nodes & bw_module_nodes
18891897
18901898 counts : dict [str , int ] = defaultdict (int )
0 commit comments