Commit 80c7c71

drisspg authored and pytorchmergebot committed
Make sure all SDPA tests are ran with tensor cores enabled (#135592)
Pull Request resolved: #135592
Approved by: https://github.com/eqy
1 parent c81d4fd commit 80c7c71
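
For context, "tensor cores enabled" in the commit title means allowing TF32 math for float32 work on Ampere-or-newer GPUs. As a quick orientation (this snippet is illustrative and not part of the diff), these are the two public torch.backends switches involved:

import torch

# Allow float32 matmuls to run on tensor cores using TF32.
torch.backends.cuda.matmul.allow_tf32 = True
# Allow cuDNN kernels (e.g. convolutions) to use TF32 as well.
torch.backends.cudnn.allow_tf32 = True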

2 files changed: 75 additions & 22 deletions

test/test_transformers.py

Lines changed: 58 additions & 22 deletions
@@ -45,12 +45,15 @@
 
 from torch.testing._internal.common_methods_invocations import wrapper_set_seed
 from torch.testing._internal.common_cuda import (
-    IS_JETSON, SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    IS_JETSON,
+    SM80OrLater,
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     PLATFORM_SUPPORTS_FUSED_ATTENTION,
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     SM90OrLater,
-    tf32_on_and_off
+    tf32_on_and_off,
+    tf32_enabled,
 )
 
 if not IS_FBCODE:
@@ -64,6 +67,7 @@
 SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim'])
 Tolerances = namedtuple('Tolerances', ['atol', 'rtol'])
 
+
 @contextlib.contextmanager
 def use_deterministic_algorithims(mode: bool, warn_only: bool):
     r"""
@@ -2998,17 +3002,30 @@ def test_mem_eff_backwards_determinism(self, device):
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("seq_len_k", [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [4, 8, 256, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 103, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [4, 8, 256, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False, True])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3097,17 +3114,30 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
-    @parametrize("seq_len_q", [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 152, 512])
-    @parametrize("seq_len_k", [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 37, 512])
-    @parametrize("head_dim", [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [8, 16, 32, 64])
+    @parametrize(
+        "seq_len_q",
+        [8, 312, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 152, 512],
+    )
+    @parametrize(
+        "seq_len_k",
+        [8, 408, 1024, 2048] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 37, 512],
+    )
+    @parametrize(
+        "head_dim",
+        [8, 16, 96, 128] if MEM_EFF_CAPABILITY_MATCHES_SM80 else [8, 16, 32, 64],
+    )
     @parametrize("is_causal", [False])
     @parametrize("dropout_p", [0.0, 0.22])
-    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if MEM_EFF_CAPABILITY_MATCHES_SM80
-                 else [torch.float16, torch.float32])
+    @parametrize(
+        "dtype",
+        (
+            [torch.float16, torch.bfloat16, torch.float32]
+            if MEM_EFF_CAPABILITY_MATCHES_SM80
+            else [torch.float16, torch.float32]
+        ),
+    )
     @parametrize("scale", [None, "l1"])
+    @tf32_enabled()
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                   seq_len_k: int, head_dim: int, is_causal: bool,
                                                                   dropout_p: float, dtype: torch.dtype,
@@ -3137,7 +3167,6 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
 
         attn_mask = torch.rand(seq_len_q, seq_len_k, device=device, dtype=dtype, requires_grad=True)
 
-
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
         query_ref, key_ref, value_ref = query_key_value_clones(query, key, value, dtype=higher_precision_dtype)
         attn_mask_ref = attn_mask.detach().to(higher_precision_dtype).requires_grad_(True)
@@ -3204,7 +3233,10 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @unittest.skipIf(IS_JETSON, "causing sigkill on Jetson")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 143, 2048])
@@ -3216,6 +3248,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("scale", [None, "l1"])
     @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
+    @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: List[int]):
@@ -3327,7 +3360,10 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
             fudge_factors=fudge_factors,
         )
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support SDPA or pre-SM80 hardware",
+    )
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 1024])
     @parametrize("seq_len_k", [256, 1024])
@@ -3337,6 +3373,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
+    @tf32_enabled()
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -3479,7 +3516,6 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d
         }
     )
 
-
     @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if

torch/testing/_internal/common_cuda.py

Lines changed: 17 additions & 0 deletions
@@ -154,6 +154,23 @@ def tf32_on(self, tf32_precision=1e-5):
         self.precision = old_precision
 
 
+@contextlib.contextmanager
+def tf32_enabled():
+    """
+    Context manager to temporarily enable TF32 for CUDA operations.
+    Restores the previous TF32 state after exiting the context.
+    """
+    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
+    try:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        with torch.backends.cudnn.flags(
+            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
+        ):
+            yield
+    finally:
+        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
+
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the
