200 changes: 200 additions & 0 deletions test/inductor/test_scatter_optimization.py
@@ -0,0 +1,200 @@
# Owner(s): ["module: inductor"]

import copy
import os

import torch
from torch import nn
from torch._dynamo.utils import counters, same
from torch._inductor import metrics
from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_GPU

DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"


class TestScatterOpt(TestCase):
    def setUp(self):
        super().setUp()
        metrics.reset()
        counters.clear()

    def check_metric(self, val=1):
        self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor)
Collaborator: Can we just compare the metrics against num_bytes_accessed?

Contributor Author: num_bytes_accessed does not work correctly for mutation. We have a StarDep injected to maintain write-after-write dependencies. That StarDep should not contribute any memory access, but we count it when estimating the number of memory accesses.

In spite of that issue, I think we can still use num_bytes_accessed to verify whether the optimization kicks in. When it does, we still access less memory even though the sparse matrix is saved to a buffer in the unit test.
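To illustrate the over-count with a toy model (hypothetical names, not Inductor's actual accounting): if the estimator simply sums buffer sizes over all deps, an ordering-only StarDep is counted even though it moves no data:

    from dataclasses import dataclass

    @dataclass
    class Dep:
        buffer_bytes: int
        ordering_only: bool  # e.g. a StarDep injected for WAW ordering

    def naive_bytes_accessed(deps):
        # counts ordering-only deps too, hence the over-estimate
        return sum(d.buffer_bytes for d in deps)

    deps = [Dep(1024, False), Dep(1024, True)]  # real write + WAW StarDep
    assert naive_bytes_accessed(deps) == 2048  # true traffic is only 1024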

Contributor Author: I've added assertions on metrics.num_bytes_accessed for the majority of test cases.

But for test_cross_entropy_loss, it's non-trivial to estimate the amount of memory accessed manually. And I don't like putting the printed number directly in the unit test either, since that does not help with understanding.

I think we can probably keep both:

  1. have assertions for num_bytes_accessed
  2. have a specific metric tracking whether the optimization is triggered. This is especially useful when the model is a bit complex. It's also useful in production models if we want to know whether specific optimizations get triggered (we'd probably need to improve tlparse to expose metrics).

WDYT?

Collaborator: @shunting314 Don't we already have some infra for checking when patterns get triggered? I think we should be able to reuse that.

Contributor Author: Do you mean something like this?

    self.assertEqual(counters["inductor"]["pattern_matcher_count"], count)
    self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], nodes)

Collaborator: @shunting314 Yeah.


    def do_acc_test(self, f, *args):
        expect = f(*args)
        actual = torch.compile(f)(*args)
        self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")

    def test_3d_tensor(self):
        L, M, N = 2, 1024, 2048

        def f(x):
            y = torch.full([L, M, N], 3.14, dtype=torch.float)
            y.scatter_(2, x.unsqueeze(2), 2.718)
            return y

        x = torch.randint(0, N, (L, M), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = (
            L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize
        )
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_non_last_dim(self):
        """
        Test the case where the scatter dimension is not the last one.
        """
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(0, x.unsqueeze(0), 2.718)
            return y

        x = torch.randint(0, M, (N,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_neg_scatter_dim(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(-1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_shorter_index_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M // 2,), dtype=torch.int64)
        self.do_acc_test(f, x)

        # No match since the index tensor is shorter. We may support this in
        # the future.
        self.assertEqual(0, counters["inductor"]["pattern_matcher_count"])

    def test_nonzero_const_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_dense(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 0, dtype=torch.float)
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, N // 2), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * (
            torch.int64.itemsize + torch.float.itemsize
        )
        # Use assertGreaterEqual rather than assertEqual due to the issue related
        # to StarDep mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_non_const(self):
        M, N = 1024, 2048

        def f(x, y):
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, 1), dtype=torch.int64)
        y = torch.randn([M, N])
        self.do_acc_test(f, x, y)

        # The generated code is quite inefficient. There are 3 kernels:
        # 1. copy from arg to buf
        # 2. scatter upon buf
        # 3. copy buf back to arg
        # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40
        expected_num_bytes = 4 * M * N * torch.float.itemsize + M * (
            torch.int64.itemsize + torch.float.itemsize
        )
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

        # The second and third kernels are both mutation kernels, so we
        # overestimate the memory accessed. Update the test once the
        # overestimation is fixed.
        over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate)

    def test_cross_entropy_loss(self):
        """
        Match full+scatter in CEL and replace it with a pointwise kernel.

        Perf data on an A100 GPU:
          Without the scatter optimization:
            ms=47.340, peak_mem=10.524 GB
          With the scatter optimization:
            ms=42.768, peak_mem=7.227 GB
        """
        B, T, D, V = 32, 1024, 768, 50257
        if not DO_PERF_TEST:
            # use a smaller V if not doing a perf test, to avoid OOM in CI
            V = V // 100
        ref_model = nn.Linear(D, V).to(torch.bfloat16)
        opt_model = copy.deepcopy(ref_model)
        ce = nn.CrossEntropyLoss()

        def f(m, x, label):
            ce(m(x).view(-1, V), label.view(-1)).backward()

        opt_f = torch.compile(f)

        x = torch.randn(B, T, D).to(torch.bfloat16)
        label = torch.randint(0, V, (B, T)).to(torch.int64)

        f(ref_model, x, label)
        ref_grad = ref_model.weight.grad
        opt_f(opt_model, x, label)
        act_grad = opt_model.weight.grad
        assert torch.allclose(
            ref_grad, act_grad, atol=1e-3, rtol=1e-3
        ), f"{ref_grad=}\n{act_grad=}"

        self.check_metric()

        if DO_PERF_TEST:
            torch.cuda.reset_peak_memory_stats()
            for _ in range(3):
                opt_f(opt_model, x, label)
            ms = do_bench(lambda: opt_f(opt_model, x, label))
            peak_mem = torch.cuda.max_memory_allocated() / 10**9
            print(f"{ms=:.3f}, {peak_mem=:.3f} GB")


if HAS_GPU:
    torch.set_default_device("cuda")


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()
10 changes: 10 additions & 0 deletions torch/_inductor/config.py
@@ -401,6 +401,16 @@ def fx_graph_remote_cache_default():
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source

# This pattern matches a special usage of scatter where:
# 1. it is applied to a constant tensor, and
# 2. the index tensor has size 1 in the scatter dimension.
# Such a pattern generates a sparse matrix when the const tensor is all-zero.
# We can lower this pattern to a pointwise kernel for more fusion opportunities
# and a smaller memory footprint.
optimize_scatter_upon_const_tensor = (
    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)


# The multiprocessing start method to use for inductor workers in the codecache.
# "subprocess", "fork", or "spawn"
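For reference, here is a minimal sketch (not part of this PR) of the kind of program the flag above governs, together with the pointwise formulation the lowering conceptually rewrites it into:

    import torch

    M, N = 4, 8
    idx = torch.randint(0, N, (M, 1))

    # The targeted pattern: scatter a scalar into a freshly created constant
    # tensor, with an index tensor of size 1 along the scatter dimension.
    y = torch.full([M, N], 0.0).scatter_(1, idx, 1.0)

    # Conceptually equivalent pointwise formulation: each output element
    # compares its coordinate along the scatter dim against the selector.
    cols = torch.arange(N).view(1, N)
    y_pointwise = torch.where(cols == idx, torch.tensor(1.0), torch.tensor(0.0))

    assert torch.equal(y, y_pointwise)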
83 changes: 83 additions & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -12,6 +12,7 @@
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.virtualized import ops

from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

@@ -216,6 +217,88 @@ def is_valid_mm_plus_mm(match: Match):
    return True


def scatter_upon_const_tensor_extra_check(m):
    if not config.optimize_scatter_upon_const_tensor:
        return False
    full_shape = m.kwargs["shape"]
    selector = m.kwargs["selector"]
    dim = m.kwargs["dim"]
    if dim < 0:
        dim += len(full_shape)

    selector_ft = selector.meta["val"]
    assert selector_ft.dim() == len(full_shape)

    for idx, select_sz, full_sz in zip(
        itertools.count(), selector_ft.shape, full_shape
    ):
        if idx == dim:
            continue

        # TODO: the pattern can be updated to support the case where the index
        # tensor is shorter. But that will need a more complex condition
        # expression, especially for multi-dimensional tensors.
        # Skip it for now.
        if isinstance(full_sz, fx.Node):
            full_sz = full_sz.meta["val"]
        if select_sz < full_sz:
            return False

    # Actually we can support sizes larger than 1; it would just be a bit
    # tedious. E.g., we could load all the index values (not many) and compare
    # them with the position in the tensor to decide what value to return.
    return selector_ft.size(dim) == 1


@register_lowering_pattern(
    CallFunction(
        aten.scatter.value,
        CallFunction(
            aten.full,
            KeywordArg("shape"),
            KeywordArg("background_val"),
            dtype=KeywordArg("dtype"),
        ),
        KeywordArg("dim"),
        KeywordArg("selector"),
        KeywordArg("val"),  # scalar value
    ),
    extra_check=scatter_upon_const_tensor_extra_check,
)
def scatter_upon_const_tensor(
    match: Match, shape, background_val, dtype, dim, selector, val
):
    """
    Match the pattern of full+scatter and replace it with a pointwise op.

    TODO: Right now the scatter value must be a scalar, but we could support
    a tensor value as well.
    """
    from torch._inductor import metrics

    metrics.num_matches_for_scatter_upon_const_tensor += 1

    selector_loader = selector.make_loader()

    def inner_fn(idx):
        selector_idx = list(idx)
        selector_idx[dim] = 0

        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )

    return ir.Pointwise.create(
        device=selector.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=shape,
    )


@register_lowering_pattern(
    CallFunction(
        aten.add,
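As a sanity check of the lowering's semantics, here is a plain-Python reference (a sketch, not part of the PR) of what the generated pointwise kernel computes per element:

    import itertools

    import torch

    def scatter_upon_const_ref(shape, background_val, dim, selector, val):
        # out[idx] = val if selector[idx with idx[dim] -> 0] == idx[dim]
        #            else background_val
        out = torch.empty(shape)
        for idx in itertools.product(*map(range, shape)):
            sel_idx = list(idx)
            sel_idx[dim] = 0
            out[idx] = val if selector[tuple(sel_idx)] == idx[dim] else background_val
        return out

    M, N = 3, 5
    sel = torch.randint(0, M, (1, N))
    expect = torch.full([M, N], 3.14).scatter_(0, sel, 2.718)
    actual = scatter_upon_const_ref([M, N], 3.14, 0, sel, 2.718)
    assert torch.allclose(expect, actual)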
3 changes: 3 additions & 0 deletions torch/_inductor/metrics.py
@@ -46,6 +46,7 @@
cpp_outer_loop_fused_inner_counts: List[int] = []

num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0
Collaborator: I think this is kinda weird lol. Oddly specific thing to add a counter for, imo.



# reset all counters
@@ -57,6 +58,7 @@ def reset():
    global cpp_to_dtype_count
    global cpp_outer_loop_fused_inner_counts
    global num_comprehensive_padding
    global num_matches_for_scatter_upon_const_tensor

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
@@ -67,6 +69,7 @@
    cpp_to_dtype_count = 0
    cpp_outer_loop_fused_inner_counts.clear()
    num_comprehensive_padding = 0
    num_matches_for_scatter_upon_const_tensor = 0


@dataclass