Description
The meta dtype of bmm differs between inference and training. For inference the dtype is left implicit, while for training an explicit but incorrect dtype is recorded.
Reproduction
import torch
from torch import nn
import torch.nn.functional as F


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.all_head_size = 12 * 64
        self.dense = nn.Linear(self.all_head_size, self.all_head_size)

    def forward(self, q, k, v):
        context_layer = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.2
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.dense(context_layer)


if __name__ == "__main__":
    mod = M().to(torch.bfloat16).eval()
    q = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    k = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    v = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    inputs = (q, k, v,)

    is_inference = 0
    if is_inference:
        # 1. INFERENCE: runs successfully
        with torch.no_grad(), torch.cpu.amp.autocast():
            compiler_mode = torch.compile(mod)
            _ = compiler_mode(*inputs)
            output = compiler_mode(*inputs)
    else:
        # 2. TRAINING: runtime error: mat1 and mat2 must have the same dtype, but got Float and BFloat16
        with torch.cpu.amp.autocast():
            compiler_mode = torch.compile(mod)
            _ = compiler_mode(*inputs)
            output = compiler_mode(*inputs)
The following runtime error is encountered for training:
File "/tmp/torchinductor_liaoxuan/nd/cndd74gsb7w2znlolklu4k6iyinqhsv62jwrnbwjwdlfprf25fem.py", line 273, in call
extern_kernels.addmm(primals_5, reinterpret_tensor(buf12, (14336, 768), (768, 1), 0), reinterpret_tensor(primals_4, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf13)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
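For reference, the dtype check that produces this error can be shown in isolation. The sketch below is only a standalone illustration using the tensor names and shapes from the failing addmm call above; it is not the generated Inductor kernel.

import torch

# Illustration of the dtype mismatch in the failing extern_kernels.addmm call:
# addmm requires mat1 and mat2 to share a dtype, so an fp32 activation buffer
# fed into a bf16 Linear weight fails with exactly the reported error.
bias = torch.randn(768, dtype=torch.bfloat16)        # primals_5 (Linear bias)
mat1 = torch.randn(14336, 768, dtype=torch.float32)  # buf12 (fp32 activation)
mat2 = torch.randn(768, 768, dtype=torch.bfloat16)   # primals_4 (Linear weight)
try:
    torch.addmm(bias, mat1, mat2)
except RuntimeError as e:
    print(e)  # mat1 and mat2 must have the same dtype, but got Float and BFloat16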
Analysis
The root cause is an incorrect meta dtype for bmm (the first bmm inside SDPA). During FakeTensorProp, the bmm meta dtypes differ between inference and training; the training dtype is wrong and is expected to be fp32 (a sketch for inspecting these meta dtypes follows the list):
- Inference: FakeTensor(..., size=(336, 512, 512))
- Training: FakeTensor(..., size=(336, 512, 512), dtype=torch.bfloat16)
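For anyone who wants to poke at this, below is a rough sketch of how the fake ("val") metadata on a decomposed graph can be inspected with make_fx. It mirrors, but is not identical to, the FakeTensorProp run inside AOTAutograd, and which nodes appear (bmm vs. a fused SDPA op) depends on the PyTorch version and the SDPA backend that gets selected.

import torch
import torch.nn.functional as F
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Rough sketch (not the exact AOTAutograd code path): trace SDPA under fake
# tensors and print the dtype recorded in each node's "val" meta entry.
def f(q, k, v):
    with torch.cpu.amp.autocast():
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.2)

q = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
k = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
v = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)

gm = make_fx(f, decomposition_table=core_aten_decompositions(), tracing_mode="fake")(q, k, v)
for node in gm.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")
        print(node.target, getattr(val, "dtype", None))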
Further investigation shows that inference goes through aot_dispatch_base, where autocast is turned off, while training goes through aot_dispatch_autograd, where autocast stays on and triggers the bug. If autocast is also turned off during training, the issue is resolved. Is this behavior by design, or is there another reason for inference and training to behave differently here?
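A possible user-level workaround (not a fix of the underlying meta registration, and assuming the mismatch only appears when tracing under autocast): since the module and inputs in the reproduction are already bf16, the compiled training run can be executed without the surrounding autocast context, reusing M and inputs from the script above.

# Possible workaround sketch: drop torch.cpu.amp.autocast() for the training
# run, since mod and the inputs are already bf16; the compiled forward/backward
# then trace with autocast off, matching the inference path.
mod = M().to(torch.bfloat16)       # keep training mode (no .eval())
compiler_mode = torch.compile(mod)
output = compiler_mode(*inputs)
output.sum().backward()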
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @ezyang @leslie-fang-intel