[Inductor][Training] Different meta dtype for inference and training #137762

@Valentine233

Description

The meta dtype of bmm differs between inference and training. For inference the dtype is implicit (and correct), while for training an explicit but wrong dtype is recorded.

Reproduction

import torch
from torch import nn
import torch.nn.functional as F

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.all_head_size = 12 * 64
        self.dense = nn.Linear(self.all_head_size, self.all_head_size)

    def forward(self, q, k, v):
        context_layer = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.2
        )
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.dense(context_layer)
 
 
if __name__ == "__main__":
    mod = M().to(torch.bfloat16).eval()
 
    q = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    k = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    v = torch.randn((28, 12, 512, 64), dtype=torch.bfloat16)
    inputs = (q, k, v,)

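    # Toggle: 1 exercises the working inference path, 0 the failing training path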
    is_inference = 0

    if is_inference:
        # 1. INFERENCE: run successfully
        with torch.no_grad(), torch.cpu.amp.autocast():
            compiler_mode = torch.compile(mod)
            _ = compiler_mode(*inputs)
            output = compiler_mode(*inputs)
    else:
        # 2. TRAINING: runtime error: mat1 and mat2 must have the same dtype, but got Float and BFloat16
        with torch.cpu.amp.autocast():
            compiler_mode = torch.compile(mod)
            _ = compiler_mode(*inputs)
            output = compiler_mode(*inputs)

The following runtime error is encountered for training:

File "/tmp/torchinductor_liaoxuan/nd/cndd74gsb7w2znlolklu4k6iyinqhsv62jwrnbwjwdlfprf25fem.py", line 273, in call
    extern_kernels.addmm(primals_5, reinterpret_tensor(buf12, (14336, 768), (768, 1), 0), reinterpret_tensor(primals_4, (768, 768), (1, 768), 0), alpha=1, beta=1, out=buf13)
RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
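
The error itself can be reduced to a standalone snippet (written for this report, not taken from the generated Inductor code): addmm requires its two matrix operands to have the same dtype, but here one buffer arrives as fp32 while the linear weight is bfloat16.

import torch

bias = torch.randn(768)
mat1 = torch.randn(14336, 768)                      # fp32 activation buffer (buf12 in the generated code)
mat2 = torch.randn(768, 768, dtype=torch.bfloat16)  # bfloat16 linear weight (primals_4)
torch.addmm(bias, mat1, mat2)  # RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16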

Analysis

The root cause is the wrong meta dtype of bmm (the first bmm in SDPA). During FakeTensorProp, the bmm meta dtypes differ between inference and training. The training dtype is wrong; it is expected to be fp32:

  • Inference: FakeTensor(..., size=(336, 512, 512))
  • Training: FakeTensor(..., size=(336, 512, 512), dtype=torch.bfloat16)

On further investigation, we find that inference goes through aot_dispatch_base, where autocast is turned off, while training goes through aot_dispatch_autograd, where autocast remains on and triggers the bug. If autocast is also turned off during training, the issue is resolved. I am wondering whether this is by design or whether there is another reason for the different behavior between inference and training.
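
To illustrate the autocast effect described above, here is a minimal sketch (written for this report, using small real tensors rather than the FakeTensors seen during tracing): under CPU autocast, bmm executes in bfloat16, so with autocast on (as in the training path) the resulting dtype diverges from the fp32 result seen with autocast off (as in the inference path).

import torch

# Arbitrary small shapes for the sketch; the repro's first bmm is (336, 512, 64) x (336, 64, 512).
a = torch.randn(4, 8, 16)   # fp32 inputs
b = torch.randn(4, 16, 8)

with torch.autocast("cpu"):
    out_on = torch.bmm(a, b)    # autocast on, as in the training path
out_off = torch.bmm(a, b)       # autocast off, as in the inference path

print(out_on.dtype, out_off.dtype)  # torch.bfloat16 torch.float32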

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @ezyang @leslie-fang-intel
