test/inductor/test_cpu_select_algorithm.py (273 changes: 271 additions & 2 deletions)
@@ -1295,8 +1295,8 @@ def forward(self, arg152_1):
                atol=atol,
                rtol=rtol,
            )
-        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
-        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 2)

Collaborator:

There is more than one bmm in this UT. Why does the count only increase by 1? Does some bmm fail the check to use the bmm template?

Collaborator (Author):

The two BMM ops have exactly the same problem dimensions, so the autotune counter is only incremented once. Only the second BMM op is fused with epilogue nodes, so the fusion counter only increases by one.
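
For illustration only, here is a hypothetical minimal module with that kind of graph (it is not the actual unit test, whose graph is not shown here): two bmm ops where only the second one feeds a pointwise op, so at most one epilogue fusion would be expected.

```python
import torch


class TwoBMM(torch.nn.Module):
    # Hypothetical example: the first bmm output is consumed directly by the
    # second bmm, and only the second bmm's output feeds a pointwise op
    # (the candidate epilogue for fusion).
    def forward(self, q, k, v):
        attn = torch.bmm(q, k)    # first bmm, no pointwise epilogue
        out = torch.bmm(attn, v)  # second bmm
        return torch.relu(out)    # pointwise epilogue on the second bmm only
```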

Collaborator:

I think there is no caching mechanism that bypasses autotuning for identical problem dimensions. cc @jgong5, is there?

Collaborator (Author):

Correction: the first BMM has an input whose stride is not 1 in the last dim, so the BMM template is not used for it.
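
To make the stride condition concrete, a minimal sketch (purely illustrative, not taken from this PR): transposing the last two dims of a contiguous batch gives a last-dim stride other than 1, which is the case described above where the BMM template is skipped for that input.

```python
import torch

# Contiguous batch of matrices: the innermost (last-dim) stride is 1.
x = torch.randn(12, 10, 62)
print(x.stride())    # (620, 62, 1)
print(x.stride(-1))  # 1

# Transposing the last two dims only changes metadata, so the innermost
# stride is no longer 1; per the comment above, such an input would not
# use the BMM template.
xt = x.transpose(1, 2)  # shape (12, 62, 10)
print(xt.stride())      # (620, 1, 62)
print(xt.stride(-1))    # 62
```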

+        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 2)

    @inductor_config.patch({"freezing": True})
    @patches

@@ -1761,6 +1761,245 @@ def forward(self, x):
            self.common(mod, (v,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (1, 50))
    @parametrize("Mdim", (192,))
    @parametrize("Kdim", (196,))
    @parametrize("Ndim", (84, 385))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_bmm(self, dtype, bs, Mdim, Kdim, Ndim):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x @ y

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u, v), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (1,))
    @parametrize("Mdim", (192,))
    @parametrize("Kdim", (196,))
    @parametrize("Ndim", (84,))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_bmm_amp(self, dtype, bs, Mdim, Kdim, Ndim):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x @ y

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol), torch.amp.autocast("cpu"):
            self.common(mod, (u, v), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (1,))
    @parametrize("Mdim", (192,))
    @parametrize("Kdim", (196,))
    @parametrize("Ndim", (64, 65))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_bmm_freezing(self, dtype, bs, Mdim, Kdim, Ndim):
        class M(torch.nn.Module):
            def __init__(self, w):
                super().__init__()
                self.w = torch.nn.Parameter(w, requires_grad=False)

            def forward(self, x):
                return x @ self.w

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        mod = M(v).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("Ndim", (64, 61))
    @parametrize(
        "order",
        (
            ((0, 1, 2), (0, 2, 1)),  # First BMM in hf_Reformer
            ((0, 1, 2), (1, 2, 0)),  # First BMM in hf_DistilBert
            ((0, 1, 2), (1, 0, 2)),  # Second BMM in hf_DistilBert, hf_T5
            ((1, 0, 2), (0, 1, 2)),  # Third BMM in hf_Reformer
            ((1, 0, 2), (1, 2, 0)),  # First in hf_T5
        ),
    )
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_bmm_2d_permute(self, Ndim, order, dtype):
        # TODO: Support bmm with transposed X
        dtype = torch.float
        bs = 12
        Mdim = 10
        Kdim = 62
        x_args = (bs, Mdim, Kdim)
        w_args = (bs, Kdim, Ndim)
        inverse_order = [torch.argsort(torch.tensor(o)).tolist() for o in order]

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w):
                if order[0] != (0, 1, 2):
                    x_order = [x_args[i] for i in inverse_order[0]]
                    x = x.reshape(x_order[0], x_order[1] * x_order[2]).clone()
                    x = x.reshape(*x_order).permute(*order[0])
                if order[1] != (0, 1, 2):
                    w_order = [w_args[i] for i in inverse_order[1]]
                    w = w.reshape(w_order[0], w_order[1] * w_order[2]).clone()
                    w = w.reshape(*w_order).permute(*order[1])
                y = x @ w
                return y

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u, v), atol=atol, rtol=rtol)
        self.assertEqual(
            counters["inductor"]["select_algorithm_autotune"],
            1 if order[0] == (0, 1, 2) else 0,
        )

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (5,))
    @parametrize("Mdim", (64,))
    @parametrize("Kdim", (96,))
    @dtypes(torch.float, torch.float16, torch.bfloat16)
    def test_bmm_self_permute(self, bs, Mdim, Kdim, dtype):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x @ x.permute(0, 2, 1)

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (5,))
    @parametrize("Mdim", (64,))
    @dtypes(torch.float)
    def test_bmm_self_square(self, bs, Mdim, dtype):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x @ x

        counters.clear()
        u = torch.randn(bs, Mdim, Mdim).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u,), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (5,))
    @parametrize("Mdim", (384,))
    @parametrize("Kdim", (96,))
    @parametrize("Ndim", (64, 65))
    @parametrize(
        "epilogue",
        (
            "relu",
            "add",
            "sub",
            "mul",
            "div",
        ),
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm_with_pointwise(self, bs, Mdim, Kdim, Ndim, epilogue, dtype):
        class M(torch.nn.Module):
            def __init__(self, epilogue, other):
                super().__init__()
                self.epilogue = _get_epilogue(epilogue, other)

            def forward(self, x, w):
                return self.epilogue(x @ w)

        counters.clear()
        x = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        w = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        other = torch.randn(bs, Mdim, Ndim).to(dtype=dtype)
        mod = M(epilogue, other).to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (x, w), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm_with_fused_epilogues(self, dtype):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mul = torch.randn(8, 8, 3136, 8).as_strided(
                    (8, 8, 3136, 8), (200704, 8, 64, 1)
                )

            def forward(self, x, w):
                x = torch.ops.aten.reshape.default(x, [64, 3137, 8])
                w = torch.ops.aten.reshape.default(w, [64, 8, 8])
                bmm = torch.ops.aten.bmm.default(x, w)
                bmm = torch.ops.aten.reshape.default(bmm, [8, 8, 3137, 8])
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    self.mul, [0, 0, 1, 0, 0, 0], 0.0
                )
                mul_2 = torch.ops.aten.mul.Tensor(bmm, 0.3535533905932738)
                add = torch.ops.aten.add.Tensor(mul_2, constant_pad_nd)
                return add

        counters.clear()
        x = torch.randn(8, 8, 3137, 8).to(dtype=dtype)
        w = torch.randn(8, 8, 8, 8).to(dtype=dtype)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (x, w), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(BaseTestSelectAlgorithm):

@@ -1800,6 +2039,36 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
        TestSelectAlgorithm.test_linear_thread_factors
    )

    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bs", (5,))
    @parametrize("Mdim", (384,))
    @parametrize("Kdim", (96,))
    @parametrize("Ndim", (64, 65))
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_bmm_with_pointwise_dynamic_shapes(self, bs, Mdim, Kdim, Ndim, dtype):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.epilogue = torch.nn.ReLU()

            def forward(self, x, other):
                return self.epilogue(x @ other)

        counters.clear()
        u = torch.randn(bs, Mdim, Kdim).to(dtype=dtype)
        v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype)
        torch._dynamo.mark_dynamic(u, 0)
        torch._dynamo.mark_dynamic(u, 1)
        torch._dynamo.mark_static(u, 2)
        torch._dynamo.mark_static(v, 2)
        mod = M().to(dtype=dtype).eval()
        with verify(dtype) as (atol, rtol):
            self.common(mod, (u, v), atol=atol, rtol=rtol)
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
instantiate_device_type_tests(