
Commit 7f0c5eb

Chillee authored and pytorchmergebot committed
Added some more flex attention tests (#125487)
Pull Request resolved: #125487
Approved by: https://github.com/yanboliang
1 parent 6d30803 commit 7f0c5eb

File tree

1 file changed (+63 −1 lines)


test/inductor/test_flex_attention.py

Lines changed: 63 additions & 1 deletion
@@ -126,7 +126,15 @@ def score_mod(score, b, h, m, n):


 class TestTemplatedSDPA(InductorTestCase):
-    def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
+    def run_test(
+        self,
+        score_mod: Callable,
+        dtype: torch.dtype = torch.float16,
+        B: int = B,
+        H: int = H,
+        S: int = S,
+        D: int = D,
+    ):
         sdpa_partial = create_attention(score_mod)
         compiled_sdpa = torch.compile(sdpa_partial)
         q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
@@ -289,6 +297,60 @@ def silu_score(score, b, h, q, kv):

         self.run_test(silu_score, dtype)

+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_padded_dense_causal(self, dtype):
+        seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1
+
+        def create_padded_dense_wrapper(orig_score_mod):
+            def njt_score_mod(qk, b, h, q, kv):
+                return torch.where(
+                    qk <= seq_len[b], orig_score_mod(qk, b, h, q, kv), -float("inf")
+                )
+
+            return njt_score_mod
+
+        causal_njt = create_padded_dense_wrapper(_causal)
+
+        self.run_test(causal_njt, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_captured_scale(self, dtype):
+        scale = torch.ones((), device="cuda", dtype=torch.int32)
+
+        def score_mod_scale(qk, b, h, q, kv):
+            return qk + scale
+
+        self.run_test(score_mod_scale, dtype)
+
+    @supported_platform
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_recompile_changed_score_mod(self, dtype):
+        scale = torch.ones((), device="cuda", dtype=torch.int32)
+        ADD = True
+
+        def score_mod_scale(qk, b, h, q, kv):
+            if ADD:
+                return qk + scale
+            else:
+                return qk * scale
+
+        self.run_test(score_mod_scale, dtype)
+        ADD = False
+        self.run_test(score_mod_scale, dtype)
+
+    @supported_platform
+    @expectedFailure  # If we capture a tensor then we can perform a reduction on it.
+    @common_utils.parametrize("dtype", test_dtypes_fast)
+    def test_captured_reduction(self, dtype):
+        scale = torch.randn((B, 8), device="cuda")
+
+        def score_mod_scale(qk, b, h, q, kv):
+            return qk + scale[b].sum(dim=-1)
+
+        self.run_test(score_mod_scale, dtype)
+
     @supported_platform
     @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571
     @common_utils.parametrize("dtype", test_dtypes)

0 commit comments