@@ -126,7 +126,15 @@ def score_mod(score, b, h, m, n):
126126
127127
128128class TestTemplatedSDPA (InductorTestCase ):
129- def run_test (self , score_mod : Callable , dtype : torch .dtype = torch .float16 ):
129+ def run_test (
130+ self ,
131+ score_mod : Callable ,
132+ dtype : torch .dtype = torch .float16 ,
133+ B : int = B ,
134+ H : int = H ,
135+ S : int = S ,
136+ D : int = D ,
137+ ):
130138 sdpa_partial = create_attention (score_mod )
131139 compiled_sdpa = torch .compile (sdpa_partial )
132140 q = torch .randn ((B , H , S , D ), dtype = dtype , device = "cuda" )
@@ -289,6 +297,60 @@ def silu_score(score, b, h, q, kv):
289297
290298 self .run_test (silu_score , dtype )
291299
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_padded_dense_causal(self, dtype):
    """Exercise a score_mod that wraps ``_causal`` with a captured-tensor mask.

    ``seq_len`` is an int32 CUDA tensor holding 1..B (one entry per batch
    index). It is not passed to the score_mod as an argument; the inner
    function reads it through its closure and indexes it with ``b``.
    """
    seq_len = torch.arange(B, device="cuda", dtype=torch.int32) + 1

    def create_padded_dense_wrapper(orig_score_mod):
        """Return a score_mod that masks ``orig_score_mod`` using ``seq_len``."""

        def njt_score_mod(qk, b, h, q, kv):
            # NOTE(review): the predicate compares the score value ``qk``
            # (not the ``q``/``kv`` index) against seq_len[b] -- confirm
            # this is the intended padding condition.
            keep = qk <= seq_len[b]
            modified = orig_score_mod(qk, b, h, q, kv)
            # Positions failing the predicate are replaced with -inf.
            return torch.where(keep, modified, -float("inf"))

        return njt_score_mod

    self.run_test(create_padded_dense_wrapper(_causal), dtype)
316+
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_scale(self, dtype):
    """Exercise a score_mod that adds a captured zero-dimensional tensor.

    The captured value is a scalar-shaped int32 CUDA tensor of ones; the
    score_mod reads it from the enclosing scope instead of receiving it as
    an argument.
    """
    captured_one = torch.ones((), device="cuda", dtype=torch.int32)

    def score_mod_scale(qk, b, h, q, kv):
        # Only the score and the captured tensor are used; the index
        # arguments are ignored.
        return qk + captured_one

    self.run_test(score_mod_scale, dtype)
326+
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_recompile_changed_score_mod(self, dtype):
    """Run one score_mod twice, flipping a closed-over Python flag in between.

    The first run takes the additive branch and the second takes the
    multiplicative branch of the same function object. Presumably this
    checks that the changed flag triggers a recompile rather than reuse of
    the first result -- verify against ``run_test``.
    """
    unit = torch.ones((), device="cuda", dtype=torch.int32)
    use_add = True

    def score_mod_scale(qk, b, h, q, kv):
        # ``use_add`` is looked up in the enclosing scope at call time, so
        # rebinding it below changes which branch runs.
        return qk + unit if use_add else qk * unit

    # First pass: additive variant.
    self.run_test(score_mod_scale, dtype)

    # Second pass: same function, multiplicative variant.
    use_add = False
    self.run_test(score_mod_scale, dtype)
342+
@supported_platform
@expectedFailure  # Reduction over a captured tensor inside score_mod; marked as an expected failure.
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_captured_reduction(self, dtype):
    """Exercise a score_mod that reduces over a row of a captured tensor.

    The captured tensor has shape (B, 8) on CUDA; the score_mod selects row
    ``b`` and sums it along the last dimension before adding it to the
    score. The test is decorated as an expected failure.
    """
    captured = torch.randn((B, 8), device="cuda")

    def score_mod_scale(qk, b, h, q, kv):
        # Collapse the selected row to a single value, then add it.
        row_total = captured[b].sum(dim=-1)
        return qk + row_total

    self.run_test(score_mod_scale, dtype)
353+
292354 @supported_platform
293355 @skip ("Triton bug " ) # https://github.com/pytorch/pytorch/issues/124571
294356 @common_utils .parametrize ("dtype" , test_dtypes )
0 commit comments