Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8089,7 +8089,9 @@ def forward(self, x):
w_transpose = torch.transpose(self.w_pre, 0, 1)
w_relu = torch.nn.functional.relu(w_transpose)
w = w_relu + self.b
return torch.matmul(x, w)
return (
torch.matmul(x, w) + self.b + torch.arange(4, dtype=torch.float16)
)

example_inputs = (torch.randn(4, 4),)
mod = Model()
Expand All @@ -8105,17 +8107,38 @@ def forward(self, x):
for n, spec in zip(placeholder_nodes, new_sig.input_specs)
if spec.target is not None
}
const_gm, _ = split_const_gm(new_gm, lifted_constants)
# [self.w_pre, self.b]
lifted_constant_names = list(lifted_constants)
lifted_constant_values = [lifted_constants[n] for n in lifted_constant_names]
const_gm, _ = split_const_gm(new_gm, False, lifted_constant_names)
counter = 0
for node in const_gm.graph.nodes:
if node.op == "call_function":
counter += 1
self.assertTrue(counter > 0)
self.assertTrue(counter == 4)
counter = 0
for n in new_gm.graph.nodes:
if n.op == "placeholder":
counter += 1
# expect 3 existing placeholders and 2 folded constants
self.assertTrue(counter == 5)
# return (self.b, folded_const, folded_const)
const_folded_value = const_gm(*lifted_constant_values)

test_input = torch.randn(4, 4)
expected = new_gm(None, None, test_input)[0]
actual = mod(test_input)
# new_gm(c_w_pre, b, x, folded_const, folded_const)
actual = new_gm(
lifted_constant_values[0],
const_folded_value[0],
test_input,
const_folded_value[1],
const_folded_value[2],
)[0]
expected = mod(test_input)
self.assertEqual(actual, expected)
const_gm, _ = split_const_gm(ep.graph_module, lifted_constants, lambda x: True)
const_gm, _ = split_const_gm(
ep.graph_module, False, lifted_constant_names, lambda x: True
)
counter = 0
for node in const_gm.graph.nodes:
if node.op == "call_function":
Expand Down
14 changes: 10 additions & 4 deletions torch/_inductor/compile_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->

def split_const_gm(
gm: GraphModule,
lifted_constants: Optional[Dict[str, Any]] = None,
skip_constructor: bool = True,
lifted_constant_names: Optional[List[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> Tuple[GraphModule, Dict[str, int]]:
"""
Expand All @@ -362,9 +363,10 @@ def split_const_gm(
run_and_get_constant_graph,
)

const_gm, const_result = run_and_get_constant_graph(
gm, lifted_constants, skip_folding_node_fn
const_gm = run_and_get_constant_graph(
gm, skip_constructor, lifted_constant_names, skip_folding_node_fn
)
const_result = const_gm() if lifted_constant_names is None else None

const_outputs = {
x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0])
Expand All @@ -384,7 +386,11 @@ def split_const_gm(
replace_node_with_constant(
gm,
node,
const_result[const_outputs[node.name]],
(
const_result[const_outputs[node.name]]
if lifted_constant_names is None
else None
),
new_const_name,
)
const_output_index[new_const_name] = const_outputs[node.name]
Expand Down
103 changes: 58 additions & 45 deletions torch/_inductor/constant_folding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import collections
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.utils._pytree as pytree
Expand All @@ -18,7 +18,7 @@
def replace_node_with_constant(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
constant: torch.Tensor,
constant: Optional[torch.Tensor] = None,
name: Optional[str] = None,
) -> None:
g = gm.graph
Expand All @@ -39,32 +39,33 @@ def replace_node_with_constant(
gm._frozen_param_count = i + 1

with g.inserting_before(node):
new_input_node = g.create_node("get_attr", qualname, (), {})
if constant is not None:
new_input_node = g.create_node("get_attr", qualname, (), {})
else:
# this is the case for lifted constants
new_input_node = g.create_node("placeholder", qualname, (), {})
node.replace_all_uses_with(new_input_node)
new_input_node.meta.update(node.meta)
g.erase_node(node)

# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
gm.register_buffer(qualname, constant)
setattr(gm, qualname, constant)
if constant is not None:
# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
gm.register_buffer(qualname, constant)
setattr(gm, qualname, constant)


def is_const_source(
node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]
node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
) -> bool:
return node.op == "get_attr" or (
node.op == "placeholder"
and lifted_constants is not None
and node.name in lifted_constants
)
return node.op == "get_attr" or node.name in (lifted_constant_names or ())


class ConstantFolder(torch.fx.Interpreter):
def __init__(
self,
gm: torch.fx.GraphModule,
skip_constructors: bool = False,
lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
lifted_constant_names: Optional[List[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
super().__init__(gm)
Expand All @@ -76,14 +77,27 @@ def __init__(
# overwrite this to deallocate env values if their only remaining use
# is the output
self.user_to_last_uses = self.node_to_last_non_output_use()
self.lifted_constants = lifted_constants
self.lifted_constant_names = lifted_constant_names
self.deferred_value = object()

def _support_dynamic_shape(self) -> bool:
# ConstantFolder not support dynamic shape now
return False

def _deduce_value(self, node: torch.fx.Node) -> Any:
return super().run_node(node)
if self.lifted_constant_names is None:
return super().run_node(node)
# if lifted_constant_names is passed in, no concrete value is available
# so we just check if all inputs have values
flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)
for inp in flattened_node_inps:
if (
isinstance(inp, torch.fx.Node)
and inp.name not in (self.lifted_constant_names or ())
and self.env[inp] != self.deferred_value
):
return self.unknown_value
return self.deferred_value

def is_impure(self, node: torch.fx.node.Node) -> bool:
def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
Expand All @@ -103,7 +117,7 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
and is_woq_int8_pattern(next(iter(node.users)))
)
) and is_const_source(
node.args[0], self.lifted_constants # type: ignore[arg-type]
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
):
# Case 1: int8_weight -> dq -> bf16_weight
# Case 2: int8_weight -> permute -> dq -> bf16_weight
Expand Down Expand Up @@ -191,7 +205,7 @@ def set_env(arg: torch.fx.Node) -> None:
# TODO - more complicated strategy
if (
self.skip_constructors
and not is_const_source(node, self.lifted_constants)
and not is_const_source(node, self.lifted_constant_names)
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
):
return self.unknown_value
Expand All @@ -207,10 +221,10 @@ def set_env(arg: torch.fx.Node) -> None:
if out == self.unknown_value:
return self.unknown_value

if not is_const_source(node, self.lifted_constants) and isinstance(
out, torch.Tensor
if not is_const_source(node, self.lifted_constant_names) and (
isinstance(out, torch.Tensor) or out == self.deferred_value
):
if out.device.type == "meta":
if out != self.deferred_value and out.device.type == "meta":
return out

if not self.insertable_tensor_check(out):
Expand Down Expand Up @@ -248,10 +262,12 @@ def run(self) -> Any: # type: ignore[override]

def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"):
if self.lifted_constants is not None and n.name in self.lifted_constants:
env[n] = self.lifted_constants[n.name]
else:
env[n] = self.unknown_value # type: ignore[assignment]
env[n] = self.unknown_value # type: ignore[assignment]
if self.lifted_constant_names is None:
return
for n in self.module.graph.nodes:
if n.name in (self.lifted_constant_names or ()):
env[n] = self.deferred_value


def constant_fold(
Expand Down Expand Up @@ -284,12 +300,15 @@ def constant_fold(

def constant_graph_tag(
gm: torch.fx.GraphModule,
lifted_constants: Optional[Dict[str, Any]],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
with torch.utils._python_dispatch._disable_current_modes():
cf = ConstantFolder(
gm, skip_constructors=True, lifted_constants=lifted_constants
gm,
skip_constructors=skip_constructors,
lifted_constant_names=lifted_constant_names,
)
cf.run()

Expand All @@ -298,7 +317,7 @@ def constant_graph_tag(
node.meta[META_TAG] = MODULE_TAG
continue
if (
is_const_source(node, lifted_constants)
is_const_source(node, lifted_constant_names)
or node in cf.node_replacements
or node in cf.replaced_uses
):
Expand All @@ -309,15 +328,18 @@ def constant_graph_tag(

def run_and_get_constant_graph(
gm: torch.fx.GraphModule,
lifted_constants: Optional[Dict[str, Any]],
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]],
) -> Tuple[torch.fx.GraphModule, Tuple[torch.Tensor, ...]]:
skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> torch.fx.GraphModule:
"""
Construct a GraphModule which corresponds to the part which could be
constant folded in provided gm.
"""

constant_graph_tag(gm, lifted_constants, skip_folding_node_fn)
constant_graph_tag(
gm, skip_constructors, lifted_constant_names, skip_folding_node_fn
)

def untag(node: torch.fx.Node) -> bool:
used_to_fold = False
Expand All @@ -329,19 +351,11 @@ def untag(node: torch.fx.Node) -> bool:
node.meta[META_TAG] = MODULE_TAG
return used_to_fold

const_args = []
if lifted_constants is not None:
placeholders = list(gm.graph.find_nodes(op="placeholder"))
for node in placeholders:
if node.meta[META_TAG] == MODULE_TAG:
continue
if untag(node):
const_args.append(lifted_constants[node.name])

# We rewrite the tags, if it's a constant being directly consumed, without
# any folding opportunity, we keep it in main gm.
for node in gm.graph.find_nodes(op="get_attr"):
untag(node)
for node in gm.graph.nodes:
    # Constant sources are either attribute fetches or lifted-constant
    # placeholders.  The FX opcode for an attribute fetch is spelled
    # "get_attr" (the same spelling is_const_source checks); the previous
    # "getattr" literal never matched any node, so get_attr constants that
    # are consumed directly were never untagged here.
    if node.op == "get_attr" or node.name in (lifted_constant_names or ()):
        untag(node)

new_graph = torch.fx.Graph()

Expand All @@ -363,5 +377,4 @@ def untag(node: torch.fx.Node) -> bool:
new_graph.lint()
new_gm = torch.fx.GraphModule(gm, new_graph)

const_result = new_gm(*const_args)
return new_gm, const_result
return new_gm
Loading