
Inductor accuracy failures for Weight Norm #140452

@bonpyt


🐛 Describe the bug

When compiling a weight-normed linear layer with `torch.compile`, some input shapes cause accuracy failures.
The `eager` and `aot_eager` backends pass, but the `inductor` backend does not.
For example, the code below passes the accuracy check with `in_features=1024` but fails with `in_features=1025`.
This causes severe regressions when training models that use weight norm.
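To have a ground truth independent of any backend, the weight-normed layer can be recomputed by hand. With its default `dim=0`, `torch.nn.utils.parametrizations.weight_norm` stores a magnitude `g` (one value per output row) and a direction `v`, and the effective weight is `w = g * v / ||v||`, with the norm taken over each row of `v`. The sketch below is a minimal pure-Python reference in double precision; the function names are hypothetical and not part of PyTorch.

```python
import math


def weight_norm_weight(g, v):
    """Effective weight w[i][j] = g[i] * v[i][j] / ||v[i]|| (row-wise norm, as with dim=0)."""
    w = []
    for g_i, row in zip(g, v):
        norm = math.sqrt(sum(x * x for x in row))
        w.append([g_i * x / norm for x in row])
    return w


def linear(x, w, b):
    """y[n][i] = sum_j x[n][j] * w[i][j] + b[i], matching torch.nn.Linear."""
    return [[sum(xj * wj for xj, wj in zip(x_row, w_row)) + b_i
             for w_row, b_i in zip(w, b)]
            for x_row in x]


def weight_normed_linear(x, g, v, b):
    """Forward pass of a weight-normed Linear layer, computed in Python floats (fp64)."""
    return linear(x, weight_norm_weight(g, v), b)


# Toy example: in_features=3, out_features=2
g = [5.0, 0.5]
v = [[3.0, 4.0, 0.0], [0.0, 0.0, 2.0]]
b = [0.0, 1.0]
x = [[1.0, 1.0, 1.0]]
print(weight_normed_linear(x, g, v, b))  # -> [[7.0, 1.5]]
```

For a real layer, `g` and `v` would come from `layer.parametrizations.weight.original0.flatten().tolist()` and `layer.parametrizations.weight.original1.tolist()`, and the result can be compared against both the eager and the compiled outputs.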

This fails on multiple PyTorch versions with Python 3.12 (tested: 2.4.1, 2.5, and the current 2.6 nightly), but works on 2.3 with Python 3.8.

```python
from functools import partial

import torch
import torch._dynamo.config
from torch import device
from torch._dynamo.repro.after_dynamo import run_repro

torch._dynamo.config.optimize_ddp = False


class Repro(torch.nn.Module):
    def __init__(self, in_features):
        super().__init__()
        # Linear layer whose weight is reparametrized as g * v / ||v||
        self.weight_normed_linear = torch.nn.utils.parametrizations.weight_norm(
            torch.nn.Linear(in_features=in_features, out_features=2)
        ).cuda()
        self.linear = torch.nn.Linear(in_features=2, out_features=1).cuda()

    def forward(self, x_0):
        x_1 = self.weight_normed_linear(x_0)
        x_2 = self.linear(x_1)
        return (x_2,)


def load_args(in_features, reader):
    # Input saved by the minifier: one fp16 tensor of shape (2, in_features)
    buf0 = reader.storage('fbae9e314f27f66ab2f21026f411a176d6711e51', 9043968, device=device(type='cuda', index=0), dtype_hint=torch.float16)
    reader.tensor(buf0, (2, in_features), dtype=torch.float16, requires_grad=True)


if __name__ == '__main__':
    # in_features=1024 passes the accuracy check; 1025 fails with backend='inductor'
    for in_features in [1024, 1025]:
        torch.compiler.reset()

        mod = Repro(in_features)
        load_args_partial = partial(load_args, in_features)
        load_args_partial.version = 0

        run_repro(mod, load_args_partial, accuracy=True, command='run',
                  save_dir='/stuff/felixb/kernel_workspace/checkpoints', autocast=True, backend='inductor')
```

Versions

This passes on 2.3.1 (with Python 3.8) but fails from 2.4 onward (with Python 3.12), and it does not appear to be fixed in the current nightly.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @BoyuanFeng


Labels

high priority, module: inductor, oncall: pt2, triaged
