
torch.diag should be dispatched to core IR for matrix input #117349

@lmuzalewski

Description

🐛 Describe the bug

When torch.diag is called on a matrix (2-D) input, it is currently dispatched to diagonal_copy, which is not part of the core ATen IR.
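For context, torch.diag behaves differently depending on input rank. A 1-D input builds a diagonal matrix. A 2-D input extracts the diagonal selected by the offset. The sketch below is an illustration added for reference and was not part of the original report. It checks that the 2-D path matches an owned copy of torch.diagonal, which is consistent with that path lowering to diagonal_copy.

```python
import torch

x = torch.arange(16, dtype=torch.float).reshape(4, 4)

# 1-D input: diag builds a 4x4 matrix with v on the main diagonal.
v = torch.tensor([1.0, 2.0, 3.0, 4.0])
m = torch.diag(v)

# 2-D input: diag extracts a diagonal; the offset selects which one.
main = torch.diag(x)        # offset 0:  x[0,0], x[1,1], ...
upper = torch.diag(x, 1)    # offset 1:  x[0,1], x[1,2], ...
lower = torch.diag(x, -1)   # offset -1: x[1,0], x[2,1], ...

# For every offset, the 2-D result equals a copy of the torch.diagonal view.
same = all(
    torch.equal(torch.diag(x, k), torch.diagonal(x, k).clone())
    for k in range(-3, 4)
)
print(main.tolist(), upper.tolist(), lower.tolist(), same)
```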

import torch
from torch._dynamo.backends.common import aot_autograd
from torch._decomp import core_aten_decompositions

# Forward compiler that prints the ATen graph produced by AOTAutograd.
def inner_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    print(fx_module.code)
    return fx_module  # run the graph unchanged

# AOTAutograd backend that applies the core ATen decompositions before compiling.
aot_backend = aot_autograd(fw_compiler=inner_compiler, decompositions=core_aten_decompositions())

def fn(i1, i2):
    return torch.diag(i1, i2)

torch._dynamo.reset()
c = torch.compile(fn, backend=aot_backend)
x = torch.rand((8, 8), dtype=torch.float)  # matrix (2-D) input
y = 0  # diagonal offset
c(x, y)
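There is a lighter-weight way to inspect the same lowering without Dynamo. You can trace fn directly with make_fx from torch.fx.experimental.proxy_tensor and pass the core ATen decomposition table. This is an assumption offered for convenience, not what the reporter ran. The exact ops printed depend on the PyTorch version, so the sketch only prints the graph and the list of ATen targets rather than asserting a specific op.

```python
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x):
    # Matrix input, default offset 0: extracts the main diagonal.
    return torch.diag(x)

x = torch.rand((8, 8), dtype=torch.float)

# Trace to an ATen-level FX graph, applying the core ATen decompositions.
gm = make_fx(fn, decomposition_table=core_aten_decompositions())(x)
print(gm.code)

# Collect the ATen ops that appear in the traced graph.
ops = [str(n.target) for n in gm.graph.nodes if n.op == "call_function"]
print(ops)
```

If the printed list contains aten.diagonal_copy.default, the decomposition table did not lower it further on that version.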

The resulting forward graph is:

def forward(self, arg0_1):
    diagonal_copy = torch.ops.aten.diagonal_copy.default(arg0_1);  arg0_1 = None
    return (diagonal_copy,)
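One possible direction is to express the matrix case with ops assumed to be in the core ATen opset, such as arange and advanced indexing (aten.index.Tensor). This is offered only as a hypothetical sketch, not as the fix adopted upstream. The helper name diag_via_index is made up for illustration. It assumes a 2-D input and an integer offset. It mirrors torch.diagonal's semantics for positive, zero, and negative offsets, including offsets that fall past the edge of the matrix.

```python
import torch

def diag_via_index(x: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Hypothetical helper: extract the offset-th diagonal of a 2-D tensor via indexing."""
    if x.dim() != 2:
        raise ValueError("diag_via_index only handles 2-D input")
    n, m = x.shape
    # Number of elements on the requested diagonal (0 if it lies outside x).
    if offset >= 0:
        length = max(0, min(n, m - offset))
    else:
        length = max(0, min(n + offset, m))
    # Starting row and column of the diagonal.
    row0 = max(0, -offset)
    col0 = max(0, offset)
    rows = torch.arange(row0, row0 + length, device=x.device)
    cols = torch.arange(col0, col0 + length, device=x.device)
    # Advanced indexing gathers x[rows[i], cols[i]] into a new (copied) 1-D tensor.
    return x[rows, cols]

# Compare against torch.diag on square and non-square inputs.
for shape in [(8, 8), (3, 5), (5, 3)]:
    x = torch.rand(shape)
    for k in range(-6, 7):
        assert torch.equal(diag_via_index(x, k), torch.diag(x, k)), (shape, k)
print("ok")
```

Advanced indexing always returns a copy, which matches the out-of-place semantics of diagonal_copy. Confirm whether it actually lowers to aten.index.Tensor on a given version with the aot_autograd setup from the report.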

Error logs

No response

Minified repro

No response

Versions

I executed this on Colab
torch.__version__ = 2.1.0+cu121

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata

Assignees

No one assigned

Labels

export-triaged: This tag is used to tag issues that have been looked at by the PT2 Export team and determined the next step
oncall: export
oncall: pt2
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module