 index = torch.ops.aten.index


+def rmse(ref, res):
+    """
+    Calculate root mean squared error
+    """
+    return torch.sqrt(torch.mean(torch.square(ref - res)))
+
+
 def create_attention(score_mod, block_mask):
     return functools.partial(flex_attention, score_mod=score_mod, block_mask=block_mask)

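The new rmse helper is the standard root-mean-squared-error reduction that the comparison test added further below uses to grade outputs against a float64 reference. A minimal standalone sketch of the idea (the tensors here are illustrative, not taken from the test):

    import torch

    def rmse(ref, res):
        return torch.sqrt(torch.mean(torch.square(ref - res)))

    golden = torch.randn(4, 8, dtype=torch.float64)     # high-precision reference
    lossy = golden.to(torch.float16).to(torch.float64)  # stands in for a low-precision kernel output
    print(rmse(golden, lossy))                          # small but non-zero error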
@@ -187,15 +194,15 @@ def _check_out_and_grad(
         self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")

         # Check gradients
-        q_fudge_factor = 2.5 * fudge_factor
+        q_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
         )
-        k_fudge_factor = 4 * fudge_factor
+        k_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
         )
-        v_fudge_factor = 4 * fudge_factor
+        v_fudge_factor = 1.0 * fudge_factor
         self._check_equal(
             v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
         )
@@ -1058,6 +1065,7 @@ def sdpa_hop(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                1.0,
             )

         @torch.compile(backend="aot_eager")
@@ -1066,13 +1074,7 @@ def eager_sdpa_hop(q, k, v, score_mod, block_mask):
             Besides dropping LSE it also ensures that the hop is compiled with aot-eager
             backend. We need to replicate this.
             """
-            return flex_attention_hop(
-                q,
-                k,
-                v,
-                score_mod,
-                block_mask.as_tuple(),
-            )
+            return flex_attention_hop(q, k, v, score_mod, block_mask.as_tuple(), 1.0)

         ref_out, ref_lse = eager_sdpa_hop(
             q.to(torch.float64),
@@ -1129,6 +1131,7 @@ def func(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                scale=1.0,
             )
             lse_2 = lse * 2
             return lse_2
@@ -1157,6 +1160,7 @@ def func(q, k, v, score_mod, block_mask):
                 v,
                 score_mod,
                 block_mask.as_tuple(),
+                1.0,
             )
             lse_2 = lse * 2
             return out, lse_2
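Each of the call sites above now threads an explicit softmax-scale argument through the flex_attention higher-order op, and the tests pass 1.0. A hedged sketch of what that value means under the usual attention convention (naive_scores is illustrative only, not an API from this PR): if scores are computed as q @ k^T * scale, then a scale of 1.0 is the identity and leaves the raw dot products untouched.

    import torch

    # Illustrative helper (assumption: scores are scaled as q @ k^T * scale).
    def naive_scores(q, k, scale):
        return (q @ k.transpose(-2, -1)) * scale

    q, k = torch.randn(2, 2, 8, 4), torch.randn(2, 2, 8, 4)
    # scale=1.0, as passed in the tests above, changes nothing.
    assert torch.equal(naive_scores(q, k, 1.0), q @ k.transpose(-2, -1))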
@@ -1211,6 +1215,49 @@ def test_captured_score_mod_aot_eager_gradcheck(
             )
         )

+    @supported_platform
+    def test_comparison_vs_sdpa(self):
+        inputs = [
+            torch.randn(
+                2, 2, 2048, 64, device="cuda", dtype=torch.float16, requires_grad=True
+            )
+            for _ in range(3)
+        ]
+        gradOut = torch.randn(2, 2, 2048, 64, device="cuda", dtype=torch.float16)
+        out_ref = torch.nn.functional.scaled_dot_product_attention(
+            *inputs, is_causal=True
+        )
+        out_ref.backward(gradOut)
+
+        def causal(score, b, h, q_idx, kv_idx):
+            return torch.where(q_idx >= kv_idx, score, -float("inf"))
+
+        inputs_flex = [i.detach().clone().requires_grad_(True) for i in inputs]
+        out_flex = torch.compile(flex_attention)(*inputs_flex, causal)
+        out_flex.backward(gradOut)
+        inputs_golden = [
+            i.detach().clone().to(dtype=torch.float64).requires_grad_(True)
+            for i in inputs
+        ]
+        out_golden = torch.nn.functional.scaled_dot_product_attention(
+            *inputs_golden, is_causal=True
+        )
+        out_golden.backward(gradOut.to(dtype=torch.float64))
+
+        for ref, flex, golden in [
+            (out_ref, out_flex, out_golden),
+            (inputs[0].grad, inputs_flex[0].grad, inputs_golden[0].grad),
+            (inputs[1].grad, inputs_flex[1].grad, inputs_golden[1].grad),
+            (inputs[2].grad, inputs_flex[2].grad, inputs_golden[2].grad),
+        ]:
+            ref_error = rmse(ref, golden)
+            flex_error = rmse(flex, golden)
+            # Note: This has been carefully tested that FlexAttention is within
+            # 10% of the average error of SDPA! Do not bump this tolerance
+            # unless you are absolutely sure you are not worsening the accuracy
+            # of FlexAttention!
+            self.assertTrue(ref_error * 1.1 > flex_error)
+
     @supported_platform
     def test_block_mask_attributes(self):
         offset = torch.zeros(8, device="cuda")
@@ -1327,7 +1374,7 @@ def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_
         child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         flex_attention_0 = self.flex_attention_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8)); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
         out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
         return (out,)

@@ -1361,7 +1408,7 @@ class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
         fw_graph = self.fw_graph
         joint_graph = self.joint_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8)); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
         getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
         getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
         getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
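Both expected graphs above now record 0.5 as the trailing scale on the forward and backward higher-order ops. For the f64[2, 2, 8, 4] inputs that matches the usual SDPA-style default of 1/sqrt(head_dim) with head_dim = 4, presumably what gets baked in at trace time when no explicit scale is supplied. A quick standalone check of that arithmetic:

    import math

    head_dim = 4  # last dimension of the f64[2, 2, 8, 4] tensors in the graphs above
    assert 1.0 / math.sqrt(head_dim) == 0.5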