
Commit ce4d951

Chillee authored and pytorchmergebot committed
Add scale kwarg to FlexAttention (and some changes that get FlexAttention numerics to be as accurate as FA2) (#130250)
After this PR, our numerical error is within 3% of FA2's for both the forward pass and the gradients. Previously, the numerical error for `dq` was 30% higher. I also added a `PRESCALE_QK` kernel option that improves performance by about 3-4% but incurs roughly 20-30% more numerical error.

![image](https://github.com/pytorch/pytorch/assets/6355099/7b5ff44e-219b-4a05-8a1b-2a0182c01ab2)

Pull Request resolved: #130250
Approved by: https://github.com/drisspg
ghstack dependencies: #130227
1 parent a7715e3 commit ce4d951
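
As a usage sketch of the new kwarg (illustrative only: it assumes the present-day public import path `torch.nn.attention.flex_attention`, a CUDA device, and the same RMSE helper the test below adds; at this commit the function still lived in a private module):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention  # assumed import path


def rmse(ref, res):
    # Same metric the new test introduces: root mean squared error.
    return torch.sqrt(torch.mean(torch.square(ref - res)))


q, k, v = (
    torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

# By default scores are scaled by 1/sqrt(head_dim); the new `scale` kwarg lets the
# caller override that, e.g. scale=1.0 disables the implicit scaling entirely.
out_default = torch.compile(flex_attention)(q, k, v)
out_unscaled = torch.compile(flex_attention)(q, k, v, scale=1.0)

# Float64 SDPA as a "golden" reference, mirroring the added accuracy test.
golden = torch.nn.functional.scaled_dot_product_attention(
    q.double(), k.double(), v.double()
)
print("rmse of default-scaled output vs golden:", rmse(out_default.double(), golden).item())
```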

File tree

5 files changed: +204 −160 lines


test/inductor/test_flex_attention.py

Lines changed: 59 additions & 12 deletions
```diff
@@ -47,6 +47,13 @@
 index = torch.ops.aten.index
 
 
+def rmse(ref, res):
+    """
+    Calculate root mean squared error
+    """
+    return torch.sqrt(torch.mean(torch.square(ref - res)))
+
+
 def create_attention(score_mod, block_mask):
     return functools.partial(flex_attention, score_mod=score_mod, block_mask=block_mask)
 
@@ -187,15 +194,15 @@ def _check_out_and_grad(
         self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
 
         # Check gradients
-        q_fudge_factor = 2.5 * fudge_factor
+        q_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
         )
-        k_fudge_factor = 4 * fudge_factor
+        k_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
         )
-        v_fudge_factor = 4 * fudge_factor
+        v_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
         )
@@ -1058,6 +1065,7 @@ def sdpa_hop(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                1.0,
             )
 
         @torch.compile(backend="aot_eager")
@@ -1066,13 +1074,7 @@ def eager_sdpa_hop(q, k, v, score_mod, block_mask):
             Besides dropping LSE it also ensures that the hop is compiled with aot-eager
             backend. We need to replicate this.
             """
-            return flex_attention_hop(
-                q,
-                k,
-                v,
-                score_mod,
-                block_mask.as_tuple(),
-            )
+            return flex_attention_hop(q, k, v, score_mod, block_mask.as_tuple(), 1.0)
 
         ref_out, ref_lse = eager_sdpa_hop(
             q.to(torch.float64),
@@ -1129,6 +1131,7 @@ def func(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                scale=1.0,
             )
             lse_2 = lse * 2
             return lse_2
@@ -1157,6 +1160,7 @@ def func(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                1.0,
             )
             lse_2 = lse * 2
             return out, lse_2
@@ -1211,6 +1215,49 @@ def test_captured_score_mod_aot_eager_gradcheck(
             )
         )
 
+    @supported_platform
+    def test_comparison_vs_sdpa(self):
+        inputs = [
+            torch.randn(
+                2, 2, 2048, 64, device="cuda", dtype=torch.float16, requires_grad=True
+            )
+            for _ in range(3)
+        ]
+        gradOut = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16)
+        out_ref = torch.nn.functional.scaled_dot_product_attention(
+            *inputs, is_causal=True
+        )
+        out_ref.backward(gradOut)
+
+        def causal(score, b, h, q_idx, kv_idx):
+            return torch.where(q_idx >= kv_idx, score, -float("inf"))
+
+        inputs_flex = [i.detach().clone().requires_grad_(True) for i in inputs]
+        out_flex = torch.compile(flex_attention)(*inputs_flex, causal)
+        out_flex.backward(gradOut)
+        inputs_golden = [
+            i.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+            for i in inputs
+        ]
+        out_golden = torch.nn.functional.scaled_dot_product_attention(
+            *inputs_golden, is_causal=True
+        )
+        out_golden.backward(gradOut.to(dtype=torch.float64))
+
+        for ref, flex, golden in [
+            (out_ref, out_flex, out_golden),
+            (inputs[0].grad, inputs_flex[0].grad, inputs_golden[0].grad),
+            (inputs[1].grad, inputs_flex[1].grad, inputs_golden[1].grad),
+            (inputs[2].grad, inputs_flex[2].grad, inputs_golden[2].grad),
+        ]:
+            ref_error = rmse(ref, golden)
+            flex_error = rmse(flex, golden)
+            # Note: This has been carefully tested that FlexAttention is within
+            # 10% of the average error of SDPA! Do not bump this tolerance
+            # unless you are absolutely sure you are not worsening the accuracy
+            # of FlexAttention!
+            self.assertTrue(ref_error < flex_error * 1.1)
+
     @supported_platform
     def test_block_mask_attributes(self):
         offset = torch.zeros(8, device="cuda")
@@ -1327,7 +1374,7 @@ def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_
         child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         flex_attention_0 = self.flex_attention_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8)); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
         out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
         return (out,)
@@ -1361,7 +1408,7 @@ class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
         fw_graph = self.fw_graph
         joint_graph = self.joint_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8)); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
         getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
         getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
         getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
```

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1599,6 +1599,7 @@ def call_function(
             value,
             score_mod,
             block_mask,
+            scale,
         ) = self.normalize_to_args(args, kwargs)
 
         p_args = self.create_wrapped_node(tx, query, score_mod)
@@ -1607,6 +1608,7 @@ def call_function(
             key,
             value,
             block_mask,
+            scale,
         ]
 
         # Store the invocation as a call
@@ -1624,7 +1626,7 @@ def call_function(
         example_value = (out_meta, lse_meta)
 
         # Compose the ordered HOO args from two parts:
-        # - inp_args: [query, key, value, block_mask]
+        # - inp_args: [query, key, value, block_mask, scale]
         # - p_args: [score_mod, *other_buffers]
         return wrap_fx_proxy(
             tx=tx,
```
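
On the `PRESCALE_QK` trade-off mentioned in the commit message: one plausible source of the extra numerical error is that prescaling folds the softmax scale into `q` while it is still fp16/bf16, which costs an extra low-precision rounding before the matmul, whereas the default applies the scale to the higher-precision score matrix afterwards. The snippet below is a toy CPU illustration of that effect (the non-power-of-two head dimension and the exact orderings are assumptions for demonstration, not the Triton kernel's code):

```python
import torch

torch.manual_seed(0)
head_dim = 96
scale = head_dim ** -0.5  # not a power of two, so fp16 pre-scaling must round

q = torch.randn(2048, head_dim, dtype=torch.float16)
k = torch.randn(2048, head_dim, dtype=torch.float16)

# Float64 reference score matrix.
ref = (q.double() @ k.double().T) * scale

# Scale applied after the matmul, on the fp32 result (default-style ordering).
post_scaled = (q.float() @ k.float().T) * scale

# Scale folded into q while it is still fp16 (PRESCALE_QK-style ordering):
# q * scale is rounded back to fp16 before the matmul, adding error.
pre_scaled = (q * scale).float() @ k.float().T


def rmse(a, b):
    return torch.sqrt(torch.mean(torch.square(a - b))).item()


print("post-scale rmse vs fp64:", rmse(post_scaled.double(), ref))
print("pre-scale  rmse vs fp64:", rmse(pre_scaled.double(), ref))
```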
