-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Closed
Closed
Copy link
Labels
Description
import torch
from torch import Tensor
from typing import *
import torch
@torch.library.custom_op("_reinplacing::add_one", mutates_args={"result"})
def add_one(x: torch.Tensor, result: torch.Tensor) -> None:
result.copy_(x + 1)
# Factory AddOne uses to allocate the output buffer that add_one writes into.
factory_op = torch.zeros_like
class AddOne(torch.autograd.Function):
    """Autograd function whose forward and backward both call ``add_one``.

    Each pass allocates a fresh buffer with ``factory_op`` and then lets the
    mutating custom op ``add_one`` fill it.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x + 1`` in a newly allocated tensor and save it for backward."""
        out = factory_op(x)
        add_one(x, out)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad):
        """Return ``saved + 1`` in a newly allocated tensor.

        ``grad`` is only used as the template passed to ``factory_op``; its
        values are not read.
        """
        (saved,) = ctx.saved_tensors
        out = factory_op(grad)
        add_one(saved, out)
        return out
@torch.compile(backend="inductor")
def f(x):
    """Apply ``AddOne`` to ``x``; compiled with the ``inductor`` backend."""
    return AddOne.apply(x)
# Repro input: random 3-element CUDA tensor that requires grad.
x = torch.randn(3, requires_grad=True, device="cuda")
y = f(x)
gives (with TORCH_LOGS=aot):
TRACED GRAPH
===== Forward graph 0 =====
/home/rzou/dev/debug-cpu1/pt-debug-cpu1/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3][1]cuda:0"):
# File: /home/rzou/dev/debug-cpu1/pt-debug-cpu1/foo.py:52 in f, code: return AddOne.apply(x)
full_default: "f32[3][1]cuda:0" = torch.ops.aten.full.default([3], 0, dtype = torch.float32, layout = torch.strided, device = device(
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops._reinplacing.add_one.default, x = primals_1, result = full
getitem_1: "f32[3][1]cuda:0" = auto_functionalized[1]; auto_functionalized = None
return (getitem_1, full_default, getitem_1)
TRACED GRAPH
===== Backward graph 0 =====
/home/rzou/dev/debug-cpu1/pt-debug-cpu1/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, full_default: "f32[3][1]cuda:0", getitem_1: "f32[3][1]cuda:0", tangents_1: "f32[3][1]cuda:0"):
# File: /home/rzou/dev/debug-cpu1/pt-debug-cpu1/foo.py:52 in f, code: return AddOne.apply(x)
auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops._reinplacing.add_one.default, x = getitem_1, result = fu
getitem_3: "f32[3][1]cuda:0" = auto_functionalized_1[1]; auto_functionalized_1 = None
return (getitem_3,)
Interestingly, using backend="aot_eager" doesn't have this problem.