Skip to content

Compiling an attention layer with dynamic shapes yields NaNs #141710

@bdhirsh

Description

@bdhirsh

Repro below, minified from a torchtune model (still investigating):

import torch


def forward(
    s0: int,
    s1: int,
    L_x_: torch.Tensor,
    L_self_modules_sa_norm_parameters_scale_: torch.Tensor,
    L_self_modules_attn_modules_q_proj_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_: torch.Tensor,
    L_self_modules_attn_num_heads: int,
    L_self_modules_attn_num_kv_heads: int,
    L_self_modules_attn_head_dim: int,
    L_self_modules_attn_modules_pos_embeddings_buffers_cache_: torch.Tensor,
    L_self_modules_attn_modules_k_proj_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_v_proj_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_output_proj_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_norm_parameters_scale_: torch.Tensor,
    L_self_modules_mlp_modules_w1_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w3_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w2_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_: torch.Tensor,
    L_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_: torch.Tensor,
    dtype_: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Minified Dynamo graph of a torchtune LoRA self-attention layer (NaN repro).

    This is the FX graph Dynamo captured for one torchtune decoder layer,
    printed back as Python and minified. It is meant to be run through
    ``torch.compile(forward, dynamic=True)``. Per the issue title, that
    compiled run yields NaNs, which is the bug under investigation. Because
    the exact op sequence may be what triggers the bug, the body is kept
    exactly as captured (including repeated slices and folded constants).

    Only the attention path is computed:

    1. RMSNorm.
    2. LoRA-augmented q and v projections, and a plain k projection.
    3. Rotary position embedding applied to q and k.
    4. Grouped-query expansion of k and v.
    5. Causal ``scaled_dot_product_attention``.

    The ``output_proj`` and ``mlp`` parameters are accepted because they were
    inputs of the captured layer, but they are never read.

    Args:
        s0: Batch size (``b`` in the original source); symbolic under
            dynamic shapes.
        s1: Sequence length (``s_x``/``s_y`` in the original source);
            also used to slice the RoPE cache.
        L_x_: Layer input, viewed as ``(s0, s1, ...)``. Its last dimension
            must be 4096, the ``rms_norm`` normalized shape.
        L_self_modules_attn_num_heads: Number of query heads.
        L_self_modules_attn_num_kv_heads: Number of key/value heads.
        L_self_modules_attn_head_dim: Per-head feature size.
        L_self_modules_attn_modules_pos_embeddings_buffers_cache_: RoPE
            cache; its first ``s1`` rows are used for both q and k.
        Remaining tensor arguments: Projection weights, LoRA A/B weights and
            norm scales of the captured module.
        dtype_: dtype the normalized activations are cast back to. It stands
            in for ``x.dtype`` in the original module.

    Returns:
        Attention output of shape
        ``(s0, num_kv_heads * (num_heads // num_kv_heads), s1, head_dim)``.
    """
    # Dynamo rebinds every graph input to a lower-case local; no compute here.
    l_x_ = L_x_
    l_self_modules_sa_norm_parameters_scale_ = L_self_modules_sa_norm_parameters_scale_
    l_self_modules_attn_modules_q_proj_parameters_weight_ = L_self_modules_attn_modules_q_proj_parameters_weight_
    l_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_
    l_self_modules_attn_num_heads = L_self_modules_attn_num_heads
    l_self_modules_attn_num_kv_heads = L_self_modules_attn_num_kv_heads
    l_self_modules_attn_head_dim = L_self_modules_attn_head_dim
    l_self_modules_attn_modules_pos_embeddings_buffers_cache_ = L_self_modules_attn_modules_pos_embeddings_buffers_cache_
    l_self_modules_attn_modules_k_proj_parameters_weight_ = L_self_modules_attn_modules_k_proj_parameters_weight_
    l_self_modules_attn_modules_v_proj_parameters_weight_ = L_self_modules_attn_modules_v_proj_parameters_weight_
    l_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_
    l_self_modules_attn_modules_output_proj_parameters_weight_ = L_self_modules_attn_modules_output_proj_parameters_weight_
    l_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_
    l_self_modules_mlp_norm_parameters_scale_ = L_self_modules_mlp_norm_parameters_scale_
    l_self_modules_mlp_modules_w1_parameters_weight_ = L_self_modules_mlp_modules_w1_parameters_weight_
    l_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_
    l_self_modules_mlp_modules_w3_parameters_weight_ = L_self_modules_mlp_modules_w3_parameters_weight_
    l_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_
    l_self_modules_mlp_modules_w2_parameters_weight_ = L_self_modules_mlp_modules_w2_parameters_weight_
    l_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_

    # ---- RMSNorm (computed in float32, then cast back to dtype_) ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/rms_norm.py:39 in forward, code: x.float(),
    float_1: torch.Tensor = l_x_.float()

    # File: /home/hirsheybar/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/functional.py:2929 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
    rms_norm: torch.Tensor = torch.rms_norm(float_1, (4096,), l_self_modules_sa_norm_parameters_scale_, 1e-05)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/rms_norm.py:43 in forward, code: ).to(x.dtype)
    # x.dtype is replaced by the explicit dtype_ argument in this repro.
    h: torch.Tensor = rms_norm.to(dtype_)

    # ---- q projection with LoRA: W_q h + scale * B(A h) ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:126 in forward, code: out = F.linear(x, self.weight, self.bias)
    out: torch.Tensor = torch._C._nn.linear(h, l_self_modules_attn_modules_q_proj_parameters_weight_, None)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:129 in forward, code: lora_out = self.lora_a(self.dropout(x))
    lora_out: torch.Tensor = torch._C._nn.linear(h, l_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_, None)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:130 in forward, code: lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
    # 2.0 is alpha / rank, folded to a constant at capture time.
    linear_2: torch.Tensor = torch._C._nn.linear(lora_out, l_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_, None)
    lora_out_1: torch.Tensor = 2.0 * linear_2

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:131 in forward, code: return out + lora_out
    q: torch.Tensor = out + lora_out_1

    # ---- split q into heads: (s0, s1, num_kv_heads * q_per_kv, head_dim) ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:237 in forward, code: q_per_kv = self.num_heads // self.num_kv_heads
    floordiv: int = l_self_modules_attn_num_heads // l_self_modules_attn_num_kv_heads

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:238 in forward, code: q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
    mul_1: int = l_self_modules_attn_num_kv_heads * floordiv
    q_1: torch.Tensor = q.view(s0, s1, mul_1, l_self_modules_attn_head_dim)

    # ---- RoPE on q ----
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:165 in forward, code: self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
    rope_cache: torch.Tensor = l_self_modules_attn_modules_pos_embeddings_buffers_cache_[slice(None, s1, None)]

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:171 in forward, code: xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    # Last dim is split into (pairs, 2) so each pair can be rotated.
    float_2: torch.Tensor = q_1.float()
    xshaped: torch.Tensor = float_2.reshape(s0, s1, mul_1, -1, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:176 in forward, code: rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    # 2048 // mul_1 is how the captured graph expresses xshaped.size(3): q's
    # width (4096 in this capture) halved into pairs, then divided across heads.
    # That equals head_dim // 2 only if num_heads * head_dim == 4096; the
    # constant is specialized to this capture.
    floordiv_2: int = 2048 // mul_1
    rope_cache_1: torch.Tensor = rope_cache.view(-1, s1, 1, floordiv_2, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    getitem_11: torch.Tensor = xshaped[(Ellipsis, 0)]
    getitem_12: torch.Tensor = rope_cache_1[(Ellipsis, 0)]
    mul_2: torch.Tensor = getitem_11 * getitem_12

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:182 in forward, code: - xshaped[..., 1] * rope_cache[..., 1],
    getitem_13: torch.Tensor = xshaped[(Ellipsis, 1)]
    getitem_14: torch.Tensor = rope_cache_1[(Ellipsis, 1)]
    mul_3: torch.Tensor = getitem_13 * getitem_14

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    sub: torch.Tensor = mul_2 - mul_3

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    # The same slices are re-read here (no CSE in the captured graph); kept as-is
    # so the repro matches the captured op sequence.
    getitem_15: torch.Tensor = xshaped[(Ellipsis, 1)]
    getitem_16: torch.Tensor = rope_cache_1[(Ellipsis, 0)]
    mul_4: torch.Tensor = getitem_15 * getitem_16

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:184 in forward, code: + xshaped[..., 0] * rope_cache[..., 1],
    getitem_17: torch.Tensor = xshaped[(Ellipsis, 0)]
    getitem_18: torch.Tensor = rope_cache_1[(Ellipsis, 1)]
    mul_5: torch.Tensor = getitem_17 * getitem_18

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    add_1: torch.Tensor = mul_4 + mul_5

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:179 in forward, code: x_out = torch.stack(
    x_out: torch.Tensor = torch.stack([sub, add_1], -1)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:190 in forward, code: x_out = x_out.flatten(3)
    # Re-merge (pairs, 2) back into the head_dim axis.
    x_out_1: torch.Tensor = x_out.flatten(3)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:191 in forward, code: return x_out.type_as(x)
    q_2: torch.Tensor = x_out_1.type_as(q_1)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:245 in forward, code: q = q.transpose(1, 2)
    # (s0, s1, heads, head_dim) -> (s0, heads, s1, head_dim) for SDPA.
    q_3: torch.Tensor = q_2.transpose(1, 2)

    # ---- k projection (no LoRA) ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:262 in forward, code: k = self.k_proj(y)
    k: torch.Tensor = torch._C._nn.linear(h, l_self_modules_attn_modules_k_proj_parameters_weight_, None)

    # ---- v projection with LoRA ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:126 in forward, code: out = F.linear(x, self.weight, self.bias)
    out_1: torch.Tensor = torch._C._nn.linear(h, l_self_modules_attn_modules_v_proj_parameters_weight_, None)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:129 in forward, code: lora_out = self.lora_a(self.dropout(x))
    lora_out_2: torch.Tensor = torch._C._nn.linear(h, l_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_, None)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:130 in forward, code: lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
    linear_6: torch.Tensor = torch._C._nn.linear(lora_out_2, l_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_, None)
    lora_out_3: torch.Tensor = 2.0 * linear_6

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:131 in forward, code: return out + lora_out
    v: torch.Tensor = out_1 + lora_out_3

    # ---- split k and v into kv heads ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:267 in forward, code: k = k.view(b, s_y, -1, self.head_dim)
    k_1: torch.Tensor = k.view(s0, s1, -1, l_self_modules_attn_head_dim)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:268 in forward, code: v = v.view(b, s_y, -1, self.head_dim)
    v_1: torch.Tensor = v.view(s0, s1, -1, l_self_modules_attn_head_dim)

    # ---- RoPE on k (same recipe as for q) ----
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:165 in forward, code: self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
    rope_cache_2: torch.Tensor = l_self_modules_attn_modules_pos_embeddings_buffers_cache_[slice(None, s1, None)]

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:171 in forward, code: xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    # 1024 // head_dim is the captured graph's expression for k_1.size(2), the
    # kv-head count. It assumes k's width is 1024 in this capture.
    float_3: torch.Tensor = k_1.float()
    floordiv_3: int = 1024 // l_self_modules_attn_head_dim
    xshaped_1: torch.Tensor = float_3.reshape(s0, s1, floordiv_3, -1, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:176 in forward, code: rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    # 512 // floordiv_3 stands for xshaped_1.size(3): k's width (1024 in this
    # capture) halved into pairs, divided across kv heads.
    floordiv_5: int = 512 // floordiv_3
    rope_cache_3: torch.Tensor = rope_cache_2.view(-1, s1, 1, floordiv_5, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    getitem_24: torch.Tensor = xshaped_1[(Ellipsis, 0)]
    getitem_25: torch.Tensor = rope_cache_3[(Ellipsis, 0)]
    mul_7: torch.Tensor = getitem_24 * getitem_25

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:182 in forward, code: - xshaped[..., 1] * rope_cache[..., 1],
    getitem_26: torch.Tensor = xshaped_1[(Ellipsis, 1)]
    getitem_27: torch.Tensor = rope_cache_3[(Ellipsis, 1)]
    mul_8: torch.Tensor = getitem_26 * getitem_27

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    sub_1: torch.Tensor = mul_7 - mul_8

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    getitem_28: torch.Tensor = xshaped_1[(Ellipsis, 1)]
    getitem_29: torch.Tensor = rope_cache_3[(Ellipsis, 0)]
    mul_9: torch.Tensor = getitem_28 * getitem_29

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:184 in forward, code: + xshaped[..., 0] * rope_cache[..., 1],
    getitem_30: torch.Tensor = xshaped_1[(Ellipsis, 0)]
    getitem_31: torch.Tensor = rope_cache_3[(Ellipsis, 1)]
    mul_10: torch.Tensor = getitem_30 * getitem_31

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    add_3: torch.Tensor = mul_9 + mul_10

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:179 in forward, code: x_out = torch.stack(
    x_out_2: torch.Tensor = torch.stack([sub_1, add_3], -1)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:190 in forward, code: x_out = x_out.flatten(3)
    x_out_3: torch.Tensor = x_out_2.flatten(3)

    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:191 in forward, code: return x_out.type_as(x)
    k_2: torch.Tensor = x_out_3.type_as(k_1)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:273 in forward, code: k = k.transpose(1, 2)
    k_3: torch.Tensor = k_2.transpose(1, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:274 in forward, code: v = v.transpose(1, 2)
    v_2: torch.Tensor = v_1.transpose(1, 2)

    # ---- grouped-query expansion ----
    # Each kv head is repeated q_per_kv (= floordiv) times so k/v end up with
    # num_kv_heads * q_per_kv heads, matching q. The -1 keeps the sequence dim.
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:289 in forward, code: k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    unsqueeze: torch.Tensor = k_3.unsqueeze(2)
    expand: torch.Tensor = unsqueeze.expand((s0, l_self_modules_attn_num_kv_heads, floordiv, -1, l_self_modules_attn_head_dim))
    k_4: torch.Tensor = expand.flatten(1, 2)

    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:290 in forward, code: v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    unsqueeze_1: torch.Tensor = v_2.unsqueeze(2)
    expand_1: torch.Tensor = unsqueeze_1.expand((s0, l_self_modules_attn_num_kv_heads, floordiv, -1, l_self_modules_attn_head_dim))
    v_3: torch.Tensor = expand_1.flatten(1, 2)

    # ---- causal attention ----
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention_utils.py:216 in _attention_call, code: return nn.functional.scaled_dot_product_attention(
    output: torch.Tensor = torch._C._nn.scaled_dot_product_attention(q_3, k_4, v_3, attn_mask = None, dropout_p = 0.0, is_causal = True);
    return output

from torch._dynamo.testing import rand_strided

# Fix the RNG so every run feeds the same random inputs. rand_strided fills the
# floating-point tensors with random values (presumably drawn from torch's
# global generator); without a seed, a NaN that only appears for some draws is
# hard to reproduce reliably.
torch.manual_seed(0)

# Graph inputs, in forward()'s positional order: batch size (s0), sequence
# length (s1), the layer input x, then weights / RoPE cache / head counts.
primals_1 = 2
primals_2 = 157
primals_3 = rand_strided((2, 157, 4096), (643072, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_4 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
primals_5 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_6 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_7 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_8 = 32
primals_9 = 8
primals_10 = 128
primals_11 = rand_strided((131072, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
primals_12 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_13 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_14 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_15 = rand_strided((1024, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_16 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_17 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_18 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_19 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
primals_20 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_21 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_22 = rand_strided((14336, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_23 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_24 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_25 = rand_strided((14336, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_26 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
primals_27 = rand_strided((8, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
primals_28 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)

# Build the positional argument list once so the compiled and eager runs are
# guaranteed to see exactly the same inputs.
args = [
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7,
    primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14,
    primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21,
    primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28,
]

# Compiled run with dynamic shapes: the behavior under investigation.
out2 = torch.compile(forward, dynamic=True)(*args, torch.bfloat16)
print(torch.any(torch.isnan(out2)))

# Eager reference on the same inputs. If this one is NaN-free while the
# compiled output above is not, the NaNs are introduced by compilation rather
# than by the inputs themselves.
out_eager = forward(*args, torch.bfloat16)
print("eager output has NaNs:", torch.any(torch.isnan(out_eager)).item())

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions