Closed
Labels: high priority, module: dynamic shapes, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
Repro below, minified from a torchtune model; still investigating:
import torch

def forward(
    s0,
    s1,
    L_x_,
    L_self_modules_sa_norm_parameters_scale_,
    L_self_modules_attn_modules_q_proj_parameters_weight_,
    L_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_,
    L_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_,
    L_self_modules_attn_num_heads,
    L_self_modules_attn_num_kv_heads,
    L_self_modules_attn_head_dim,
    L_self_modules_attn_modules_pos_embeddings_buffers_cache_,
    L_self_modules_attn_modules_k_proj_parameters_weight_,
    L_self_modules_attn_modules_v_proj_parameters_weight_,
    L_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_,
    L_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_,
    L_self_modules_attn_modules_output_proj_parameters_weight_,
    L_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_,
    L_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_,
    L_self_modules_mlp_norm_parameters_scale_,
    L_self_modules_mlp_modules_w1_parameters_weight_,
    L_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_,
    L_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_,
    L_self_modules_mlp_modules_w3_parameters_weight_,
    L_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_,
    L_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_,
    L_self_modules_mlp_modules_w2_parameters_weight_,
    L_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_,
    L_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_,
    dtype_=torch.bfloat16,
):
    l_x_ = L_x_
    l_self_modules_sa_norm_parameters_scale_ = L_self_modules_sa_norm_parameters_scale_
    l_self_modules_attn_modules_q_proj_parameters_weight_ = L_self_modules_attn_modules_q_proj_parameters_weight_
    l_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_
    l_self_modules_attn_num_heads = L_self_modules_attn_num_heads
    l_self_modules_attn_num_kv_heads = L_self_modules_attn_num_kv_heads
    l_self_modules_attn_head_dim = L_self_modules_attn_head_dim
    l_self_modules_attn_modules_pos_embeddings_buffers_cache_ = L_self_modules_attn_modules_pos_embeddings_buffers_cache_
    l_self_modules_attn_modules_k_proj_parameters_weight_ = L_self_modules_attn_modules_k_proj_parameters_weight_
    l_self_modules_attn_modules_v_proj_parameters_weight_ = L_self_modules_attn_modules_v_proj_parameters_weight_
    l_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_
    l_self_modules_attn_modules_output_proj_parameters_weight_ = L_self_modules_attn_modules_output_proj_parameters_weight_
    l_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_ = L_self_modules_attn_modules_output_proj_modules_lora_a_parameters_weight_
    l_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_ = L_self_modules_attn_modules_output_proj_modules_lora_b_parameters_weight_
    l_self_modules_mlp_norm_parameters_scale_ = L_self_modules_mlp_norm_parameters_scale_
    l_self_modules_mlp_modules_w1_parameters_weight_ = L_self_modules_mlp_modules_w1_parameters_weight_
    l_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w1_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w1_modules_lora_b_parameters_weight_
    l_self_modules_mlp_modules_w3_parameters_weight_ = L_self_modules_mlp_modules_w3_parameters_weight_
    l_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w3_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w3_modules_lora_b_parameters_weight_
    l_self_modules_mlp_modules_w2_parameters_weight_ = L_self_modules_mlp_modules_w2_parameters_weight_
    l_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_ = L_self_modules_mlp_modules_w2_modules_lora_a_parameters_weight_
    l_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_ = L_self_modules_mlp_modules_w2_modules_lora_b_parameters_weight_
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/rms_norm.py:39 in forward, code: x.float(),
    float_1 = l_x_.float()
    # File: /home/hirsheybar/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/functional.py:2929 in rms_norm, code: return torch.rms_norm(input, normalized_shape, weight, eps)
    rms_norm = torch.rms_norm(float_1, (4096,), l_self_modules_sa_norm_parameters_scale_, 1e-05)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/rms_norm.py:43 in forward, code: ).to(x.dtype)
    h = rms_norm.to(dtype_)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:126 in forward, code: out = F.linear(x, self.weight, self.bias)
    out = torch._C._nn.linear(h, l_self_modules_attn_modules_q_proj_parameters_weight_, None)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:129 in forward, code: lora_out = self.lora_a(self.dropout(x))
    lora_out = torch._C._nn.linear(h, l_self_modules_attn_modules_q_proj_modules_lora_a_parameters_weight_, None)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:130 in forward, code: lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
    linear_2 = torch._C._nn.linear(lora_out, l_self_modules_attn_modules_q_proj_modules_lora_b_parameters_weight_, None)
    lora_out_1 = 2.0 * linear_2
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:131 in forward, code: return out + lora_out
    q = out + lora_out_1
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:237 in forward, code: q_per_kv = self.num_heads // self.num_kv_heads
    floordiv = l_self_modules_attn_num_heads // l_self_modules_attn_num_kv_heads
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:238 in forward, code: q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
    mul_1 = l_self_modules_attn_num_kv_heads * floordiv
    q_1 = q.view(s0, s1, mul_1, l_self_modules_attn_head_dim)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:165 in forward, code: self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
    rope_cache = l_self_modules_attn_modules_pos_embeddings_buffers_cache_[slice(None, s1, None)]
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:171 in forward, code: xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    float_2 = q_1.float()
    xshaped = float_2.reshape(s0, s1, mul_1, -1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:176 in forward, code: rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    floordiv_2 = 2048 // mul_1
    rope_cache_1 = rope_cache.view(-1, s1, 1, floordiv_2, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    getitem_11 = xshaped[(Ellipsis, 0)]
    getitem_12 = rope_cache_1[(Ellipsis, 0)]
    mul_2 = getitem_11 * getitem_12
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:182 in forward, code: - xshaped[..., 1] * rope_cache[..., 1],
    getitem_13 = xshaped[(Ellipsis, 1)]
    getitem_14 = rope_cache_1[(Ellipsis, 1)]
    mul_3 = getitem_13 * getitem_14
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    sub = mul_2 - mul_3
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    getitem_15 = xshaped[(Ellipsis, 1)]
    getitem_16 = rope_cache_1[(Ellipsis, 0)]
    mul_4 = getitem_15 * getitem_16
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:184 in forward, code: + xshaped[..., 0] * rope_cache[..., 1],
    getitem_17 = xshaped[(Ellipsis, 0)]
    getitem_18 = rope_cache_1[(Ellipsis, 1)]
    mul_5 = getitem_17 * getitem_18
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    add_1 = mul_4 + mul_5
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:179 in forward, code: x_out = torch.stack(
    x_out = torch.stack([sub, add_1], -1)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:190 in forward, code: x_out = x_out.flatten(3)
    x_out_1 = x_out.flatten(3)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:191 in forward, code: return x_out.type_as(x)
    q_2 = x_out_1.type_as(q_1)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:245 in forward, code: q = q.transpose(1, 2)
    q_3 = q_2.transpose(1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:262 in forward, code: k = self.k_proj(y)
    k = torch._C._nn.linear(h, l_self_modules_attn_modules_k_proj_parameters_weight_, None)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:126 in forward, code: out = F.linear(x, self.weight, self.bias)
    out_1 = torch._C._nn.linear(h, l_self_modules_attn_modules_v_proj_parameters_weight_, None)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:129 in forward, code: lora_out = self.lora_a(self.dropout(x))
    lora_out_2 = torch._C._nn.linear(h, l_self_modules_attn_modules_v_proj_modules_lora_a_parameters_weight_, None)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:130 in forward, code: lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
    linear_6 = torch._C._nn.linear(lora_out_2, l_self_modules_attn_modules_v_proj_modules_lora_b_parameters_weight_, None)
    lora_out_3 = 2.0 * linear_6
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/peft/lora.py:131 in forward, code: return out + lora_out
    v = out_1 + lora_out_3
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:267 in forward, code: k = k.view(b, s_y, -1, self.head_dim)
    k_1 = k.view(s0, s1, -1, l_self_modules_attn_head_dim)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:268 in forward, code: v = v.view(b, s_y, -1, self.head_dim)
    v_1 = v.view(s0, s1, -1, l_self_modules_attn_head_dim)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:165 in forward, code: self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
    rope_cache_2 = l_self_modules_attn_modules_pos_embeddings_buffers_cache_[slice(None, s1, None)]
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:171 in forward, code: xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    float_3 = k_1.float()
    floordiv_3 = 1024 // l_self_modules_attn_head_dim
    xshaped_1 = float_3.reshape(s0, s1, floordiv_3, -1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:176 in forward, code: rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
    floordiv_5 = 512 // floordiv_3
    rope_cache_3 = rope_cache_2.view(-1, s1, 1, floordiv_5, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    getitem_24 = xshaped_1[(Ellipsis, 0)]
    getitem_25 = rope_cache_3[(Ellipsis, 0)]
    mul_7 = getitem_24 * getitem_25
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:182 in forward, code: - xshaped[..., 1] * rope_cache[..., 1],
    getitem_26 = xshaped_1[(Ellipsis, 1)]
    getitem_27 = rope_cache_3[(Ellipsis, 1)]
    mul_8 = getitem_26 * getitem_27
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:181 in forward, code: xshaped[..., 0] * rope_cache[..., 0]
    sub_1 = mul_7 - mul_8
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    getitem_28 = xshaped_1[(Ellipsis, 1)]
    getitem_29 = rope_cache_3[(Ellipsis, 0)]
    mul_9 = getitem_28 * getitem_29
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:184 in forward, code: + xshaped[..., 0] * rope_cache[..., 1],
    getitem_30 = xshaped_1[(Ellipsis, 0)]
    getitem_31 = rope_cache_3[(Ellipsis, 1)]
    mul_10 = getitem_30 * getitem_31
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:183 in forward, code: xshaped[..., 1] * rope_cache[..., 0]
    add_3 = mul_9 + mul_10
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:179 in forward, code: x_out = torch.stack(
    x_out_2 = torch.stack([sub_1, add_3], -1)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:190 in forward, code: x_out = x_out.flatten(3)
    x_out_3 = x_out_2.flatten(3)
    # File: /home/hirsheybar/local/torchtune/torchtune/models/llama3_1/_position_embeddings.py:191 in forward, code: return x_out.type_as(x)
    k_2 = x_out_3.type_as(k_1)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:273 in forward, code: k = k.transpose(1, 2)
    k_3 = k_2.transpose(1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:274 in forward, code: v = v.transpose(1, 2)
    v_2 = v_1.transpose(1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:289 in forward, code: k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    unsqueeze = k_3.unsqueeze(2)
    expand = unsqueeze.expand((s0, l_self_modules_attn_num_kv_heads, floordiv, -1, l_self_modules_attn_head_dim))
    k_4 = expand.flatten(1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention.py:290 in forward, code: v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
    unsqueeze_1 = v_2.unsqueeze(2)
    expand_1 = unsqueeze_1.expand((s0, l_self_modules_attn_num_kv_heads, floordiv, -1, l_self_modules_attn_head_dim))
    v_3 = expand_1.flatten(1, 2)
    # File: /home/hirsheybar/local/torchtune/torchtune/modules/attention_utils.py:216 in _attention_call, code: return nn.functional.scaled_dot_product_attention(
    output = torch._C._nn.scaled_dot_product_attention(q_3, k_4, v_3, attn_mask=None, dropout_p=0.0, is_causal=True)
    return output
from torch._dynamo.testing import rand_strided
primals_1 = 2
primals_2 = 157
primals_3 = rand_strided((2, 157, 4096), (643072, 4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_4 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
primals_5 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_6 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_7 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_8 = 32
primals_9 = 8
primals_10 = 128
primals_11 = rand_strided((131072, 64, 2), (128, 2, 1), device='cuda:0', dtype=torch.float32)
primals_12 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_13 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_14 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_15 = rand_strided((1024, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_16 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_17 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_18 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_19 = rand_strided((4096, ), (1, ), device='cuda:0', dtype=torch.bfloat16)
primals_20 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_21 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_22 = rand_strided((14336, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_23 = rand_strided((14336, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_24 = rand_strided((8, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
primals_25 = rand_strided((14336, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
primals_26 = rand_strided((4096, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
primals_27 = rand_strided((8, 14336), (14336, 1), device='cuda:0', dtype=torch.bfloat16)
primals_28 = rand_strided((4096, 8), (8, 1), device='cuda:0', dtype=torch.bfloat16)
#out2 = forward(*[primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28], torch.bfloat16)
out2 = torch.compile(forward, dynamic=True)(*[primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14, primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28], torch.bfloat16)
print(torch.any(torch.isnan(out2)))
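Not part of the original repro, but a minimal sketch of how one might compare the compiled result against eager execution of the same forward (assumes the primals_* tensors and out2 above have already been created; the bf16 tolerances are a guess):

args = [
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7,
    primals_8, primals_9, primals_10, primals_11, primals_12, primals_13, primals_14,
    primals_15, primals_16, primals_17, primals_18, primals_19, primals_20, primals_21,
    primals_22, primals_23, primals_24, primals_25, primals_26, primals_27, primals_28,
]
# Eager baseline with the exact same inputs used for the compiled call above.
out_eager = forward(*args, torch.bfloat16)
print("eager NaNs:   ", torch.isnan(out_eager).any().item())
print("compiled NaNs:", torch.isnan(out2).any().item())
# Element-wise comparison in fp32, with loose tolerances for bf16.
print("allclose:     ", torch.allclose(out_eager.float(), out2.float(), atol=1e-2, rtol=1e-2))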
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93