@@ -45,12 +45,15 @@
 
 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON,
+    SM80OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     SM90OrLater,
-    tf32_on_and_off
+    tf32_on_and_off,
+    tf32_enabled,
 )
 
 if not IS_FBCODE:
@@ -64,6 +64,7 @@
 SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
 Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
 
+
 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):
     r"""
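The substantive change in the import hunk is the new `tf32_enabled` helper, applied as `@tf32_enabled()` to the gradient-comparison tests in the hunks below. The actual helper lives in `torch.testing._internal.common_cuda`; as a rough, hypothetical sketch (an assumption, not the real implementation), a decorator with that behavior could allow TF32 matmuls for the duration of the wrapped test and restore the previous setting afterwards:

```python
# Hypothetical sketch of a tf32_enabled-style decorator. The real helper in
# torch.testing._internal.common_cuda may differ in signature and in which
# backends it touches.
import functools

import torch


def tf32_enabled_sketch():
    """Decorator factory: run the wrapped test with TF32 matmuls allowed."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            previous = torch.backends.cuda.matmul.allow_tf32
            torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for the test body
            try:
                return fn(*args, **kwargs)
            finally:
                torch.backends.cuda.matmul.allow_tf32 = previous  # restore prior setting
        return wrapper
    return decorator
```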
@@ -2998,17 +3002,30 @@ def test_mem_eff_backwards_determinism(self, device):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False, True])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                         head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                         scale: str):
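The test hunks in this diff all exercise the same comparison: a fused SDPA backend is checked against the math reference computed in higher precision, for both outputs and gradients. A simplified, self-contained illustration of that pattern (not the actual test body; shapes, dtypes, and tolerances are placeholders, and the real tests additionally handle dropout masks and per-tensor fudge factors):

```python
# Simplified illustration of the fused-vs-math-reference comparison used by the
# tests in this diff. Shapes, dtypes, and tolerances are illustrative only.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.rand(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

# Forward with the fused (memory-efficient) kernel.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

# Reference forward with the math backend in higher precision.
q_ref, k_ref, v_ref = (t.detach().to(torch.float32).requires_grad_(True) for t in (q, k, v))
with sdpa_kernel(SDPBackend.MATH):
    out_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref)

# Compare outputs and input gradients within loosened tolerances.
out.sum().backward()
out_ref.sum().backward()
torch.testing.assert_close(out, out_ref.to(out.dtype), atol=2e-3, rtol=2e-3)
torch.testing.assert_close(q.grad, q_ref.grad.to(q.dtype), atol=2e-3, rtol=2e-3)
```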
@@ -3097,17 +3114,30 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 152, 512])
-    @parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 37, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                   seq_len_k: int, head_dim: int, is_causal: bool,
                                                                   dropout_p: float, dtype: torch.dtype,
@@ -3137,7 +3167,6 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
 
         attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
 
-
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
@@ -3204,7 +3233,10 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 143, 2048])
@@ -3216,6 +3248,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("scale", [None, "l1"])
     @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
+    @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: List[int]):
@@ -3327,7 +3360,10 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 1024])
     @parametrize("seq_len_k", [256, 1024])
@@ -3337,6 +3373,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
+    @tf32_enabled()
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                           seq_len_q: int, seq_len_k: int,
                                                           head_dim: int,
@@ -3479,7 +3516,6 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
             }
         )
 
-
     @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
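With `@tf32_enabled()` applied, float32 runs of these tests compare against the math reference under TF32 precision, so the looser TF32 tolerances (rather than strict fp32 ones) apply. When reproducing tolerance failures locally, the global TF32 switches that such helpers presumably toggle are ordinary PyTorch flags and can be inspected or set directly:

```python
import torch

# Global TF32 switches (standard PyTorch flags, not introduced by this diff).
print(torch.backends.cuda.matmul.allow_tf32)  # TF32 for cuBLAS matmuls
print(torch.backends.cudnn.allow_tf32)        # TF32 for cuDNN convolutions

# Higher-level knob: "highest" keeps true fp32 matmuls, "high" permits TF32.
torch.set_float32_matmul_precision("high")
```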