[inductor] [cpu] [silent] avg_pool2d incorrectly processes int64 #143738

@shaoyuyoung

🐛 Describe the bug

I think this is related to #143729, but the symptom is different.
In #143729, CPU inductor raises a compile error; this time, avg_pool2d silently produces an incorrect result.
Should this be high-pri?

BTW, CUDA would reject the Long dtype.

Affected ops: avg_pool1d, avg_pool2d, and avg_pool3d

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
from torch._inductor import config

config.fallback_random = True


class Model(torch.nn.Module):

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        torch.manual_seed(0)
        x = torch.argsort(x, dim=3)
        # x.dtype: torch.int64
        x = F.avg_pool2d(x, kernel_size=2, stride=2)
        return x


model = Model()


x = torch.randn(1, 1, 2, 4)

inputs = [x]

# Eager result (reference)
output = model(*inputs)

# Result from the inductor-compiled model (CPU)
c_model = torch.compile(model)
c_output = c_model(*inputs)

print(output)    # eager
print(c_output)  # compiled

Error logs

tensor([[[[1, 2]]]])
tensor([[[[0, 0]]]])
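
For reference, the eager result above is consistent with a truncating integer average over each 2x2 window. Below is a minimal pure-Python sketch of that expectation. The helper name `avg_pool2d_int`, the truncating-division behavior, and the input values are assumptions chosen for illustration. The input is hypothetical and is not the actual argsort output from the repro.

```python
def avg_pool2d_int(plane, kernel=2, stride=2):
    """Reference integer average pooling over a 2D list of ints (no padding).

    Hypothetical helper for illustration; not a PyTorch API.
    """
    rows, cols = len(plane), len(plane[0])
    out = []
    for r in range(0, rows - kernel + 1, stride):
        out_row = []
        for c in range(0, cols - kernel + 1, stride):
            window_sum = sum(
                plane[r + i][c + j]
                for i in range(kernel)
                for j in range(kernel)
            )
            # Assumption: integer inputs are averaged with truncating division.
            # Inputs here are non-negative, so // matches truncation.
            out_row.append(window_sum // (kernel * kernel))
        out.append(out_row)
    return out


# Hypothetical 2x4 plane whose rows are permutations of 0..3,
# like the rows an argsort over dim=3 would produce.
hypothetical = [[0, 1, 2, 3], [3, 0, 1, 2]]
print(avg_pool2d_int(hypothetical))  # [[1, 2]]
```

Under these assumptions, an all-zero result such as the compiled output would only be expected if every window sum were below 4. That cannot happen here, because the two window sums always add up to 12.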

Versions

Exactly the same as #143729

cc @chauhang @penguinwu

Labels

    oncall: cpu inductor (CPU Inductor issues for Intel team to triage)
    oncall: pt2
    triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
