Skip to content

Partitioner doesn't recompute aten.full despite it being in the default_recomputable_ops #134468

@zou3519

Description

@zou3519
import torch
from torch import Tensor
from typing import *


import torch

@torch.library.custom_op("_reinplacing::add_one", mutates_args={"result"})
def add_one(x: torch.Tensor, result: torch.Tensor) -> None:
    result.copy_(x + 1)

# Factory used by AddOne to allocate the output buffer in both forward and backward.
factory_op = torch.zeros_like

class AddOne(torch.autograd.Function):
    """Autograd function whose forward and backward both run ``add_one``.

    Each pass allocates a fresh buffer with ``factory_op`` and lets the
    custom op fill it in place.
    """

    @staticmethod
    def forward(ctx, x):
        """Fill a ``factory_op(x)`` buffer via ``add_one`` and save it for backward."""
        buffer = factory_op(x)
        add_one(x, buffer)
        ctx.save_for_backward(buffer)
        return buffer

    @staticmethod
    def backward(ctx, grad):
        """Fill a ``factory_op(grad)`` buffer via ``add_one`` applied to the saved forward output."""
        (saved_out,) = ctx.saved_tensors
        grad_buffer = factory_op(grad)
        add_one(saved_out, grad_buffer)
        return grad_buffer

@torch.compile(backend="inductor")
def f(x):
    """Apply ``AddOne`` to ``x``; compiled with the inductor backend."""
    result = AddOne.apply(x)
    return result

# Build a 3-element CUDA input that requires grad and call the compiled function once.
x = torch.randn(3, requires_grad=True, device="cuda")
y = f(x)

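Running this with TORCH_LOGS=aot gives the graphs below. Note that full_default is returned from the forward graph and taken as an input by the backward graph, instead of being recomputed in the backward graph: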

TRACED GRAPH
 ===== Forward graph 0 =====
 /home/rzou/dev/debug-cpu1/pt-debug-cpu1/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[3][1]cuda:0"):
         # File: /home/rzou/dev/debug-cpu1/pt-debug-cpu1/foo.py:52 in f, code: return AddOne.apply(x)
        full_default: "f32[3][1]cuda:0" = torch.ops.aten.full.default([3], 0, dtype = torch.float32, layout = torch.strided, device = device(

        auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops._reinplacing.add_one.default, x = primals_1, result = full

        getitem_1: "f32[3][1]cuda:0" = auto_functionalized[1];  auto_functionalized = None
        return (getitem_1, full_default, getitem_1)
        

TRACED GRAPH
 ===== Backward graph 0 =====
 /home/rzou/dev/debug-cpu1/pt-debug-cpu1/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, full_default: "f32[3][1]cuda:0", getitem_1: "f32[3][1]cuda:0", tangents_1: "f32[3][1]cuda:0"):
         # File: /home/rzou/dev/debug-cpu1/pt-debug-cpu1/foo.py:52 in f, code: return AddOne.apply(x)
        auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops._reinplacing.add_one.default, x = getitem_1, result = fu

        getitem_3: "f32[3][1]cuda:0" = auto_functionalized_1[1];  auto_functionalized_1 = None
        return (getitem_3,)

Interestingly, using backend="aot_eager" instead of backend="inductor" doesn't have this problem.

cc @ezyang @chauhang @penguinwu

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions