
custom ops don't reinplace when mutated arg is a view of a graph input #131192

@zou3519

import torch
from torch import Tensor

@torch.library.custom_op("mylib::foo", mutates_args={"x"})
def foo(x: Tensor) -> None:
    x.sin_()

@torch.compile(fullgraph=True)
def f(x):
    x0 = x[0]
    foo(x0)

x = torch.randn(2)
f(x)
"""
# produces the following
def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (2, ), (1, ))
    buf0 = empty_strided_cpu((1, ), (1, ), torch.float32)
    cpp_fused_0(arg0_1, buf0)
    # Source Nodes: [], Original ATen: []
    torch.ops.mylib.foo.default(reinterpret_tensor(buf0, (), (), 0))
    cpp_fused_1(buf0, arg0_1)
    del arg0_1
    return ()
"""

Changing the `foo(x0)` call to `x0.sin_()` does lead to Inductor reinplacing the mutation:

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (2, ), (1, ))
    cpp_fused_sin_0(arg0_1, arg0_1)
    del arg0_1
    return ()
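
To make the difference between the two outputs concrete, here is a toy model in plain Python (no torch involved). It assumes, from reading the generated code above, that `cpp_fused_0` copies `x[0]` into `buf0` and that `cpp_fused_1` copies `buf0` back into `arg0_1`; the kernel bodies aren't shown here, so that reading is an assumption. The function names `run_without_reinplacing` and `run_reinplaced` are hypothetical and exist only for this sketch.

```python
import math


def run_without_reinplacing(x):
    """Toy model of the first `call`: clone the view, mutate the clone, copy back.

    Returns the number of extra element copies performed.
    """
    buf0 = [x[0]]                # assumed role of cpp_fused_0: copy x[0] into a new buffer
    buf0[0] = math.sin(buf0[0])  # stands in for mylib::foo mutating the temporary
    x[0] = buf0[0]               # assumed role of cpp_fused_1: write the result back into x
    return 2


def run_reinplaced(x):
    """Toy model of the second `call`: mutate the graph input directly."""
    x[0] = math.sin(x[0])        # stands in for cpp_fused_sin_0(arg0_1, arg0_1)
    return 0


a = [0.5, 2.0]
b = [0.5, 2.0]
copies_a = run_without_reinplacing(a)
copies_b = run_reinplaced(b)
print(a == b, copies_a, copies_b)
```

Both paths leave the input in the same final state; under this reading, the cost of the missed reinplacing is the temporary buffer and the two copies around the custom op.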

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @coconutruben @bdhirsh @ezyang @anijain2305 @yf225 @ColinPeppler @desertfire

Labels

module: custom-operators, module: inductor, module: pt2-dispatcher, module: reinplacing, module: vllm, oncall: pt2, triaged, vllm-compile
