Description
This is what I am trying to do:
- Create an input tensor and a target tensor
- Create a list of models for the ensemble
- Use `combine_state_for_ensemble` and `vmap` to build the ensemble
- Pass the vmapped function to `aot_function` to obtain the forward and backward graphs
Repro:
```python
from typing import List

import torch
from torch._subclasses import FakeTensorMode, FakeTensor
from functorch.compile import aot_function, print_compile, config, aot_module
from functorch import make_functional_with_buffers, vmap, combine_state_for_ensemble
from functorch._src.named_members_polyfill import _named_parameters, _named_buffers
from torchvision.models import resnet18

g = {}

def fake_wrapper(gtype):
    # Compiler hook: prints the traced graph and stashes it in g under gtype
    def fake_compiler(fx_g, inps):
        print(fx_g.code)
        g[gtype] = fx_g
        return fx_g
    return fake_compiler

inp = torch.randn(32, 3, 224, 224, dtype=torch.float32).cuda()
targets = torch.zeros(32, dtype=int).cuda()

b_models: List[torch.nn.Module] = [resnet18().cuda() for _ in range(5)]
func_model, params, buffers = combine_state_for_ensemble(b_models)
for p in params:
    p.requires_grad = True

def compute_loss(weights, buffers, batch, targets):
    output = func_model(weights, buffers, batch)
    loss = torch.nn.functional.nll_loss(output, targets)
    return loss

parallel_func = vmap(compute_loss, in_dims=(0, 0, None, None))
aot_func = aot_function(parallel_func, fake_wrapper("forward"), fake_wrapper("backward"))
out = aot_func(params, buffers, inp, targets)
out.mean().backward()
```
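For reference, the `fake_wrapper` hooks above stash each traced `torch.fx.GraphModule` in `g`, so on a successful run the forward and backward graphs could be inspected like this (a minimal sketch; it assumes both compiler hooks fired):

```python
# Dump the generated Python source of each captured graph.
# g is only populated if aot_func traced without error.
for gtype, fx_g in g.items():
    print(f"=== {gtype} graph ===")
    print(fx_g.code)
```

Here, however, tracing never gets that far and fails with the assert below.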
Error:
```
RuntimeError: !self.requires_grad() || self.is_contiguous() INTERNAL ASSERT FAILED at "/scratch/anijain/work/pytorch/aten/src/ATen/native/TensorShape.cpp":3609, please report a bug to PyTorch. as_strided_scatter is currently only supported for contiguous inputs

While executing %new_empty_strided_1 : [#users=1] = call_function[target=torch.ops.aten.new_empty_strided.default](args = (%copy__1, [5, 512, 32, 7, 7], [25088, 49, 125440, 7, 1]), kwargs = {})
Original traceback: Gradient addition node due to multiple use of tensor around:
```
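The assert itself comes from `as_strided_scatter` rejecting non-contiguous inputs that require grad. Outside of AOTAutograd, that constraint might be exercised directly with a few lines (a hypothetical sketch, not taken from the issue; it assumes the check fires for any non-contiguous, grad-requiring input):

```python
import torch

base = torch.randn(4, 4).t()  # transposed view: non-contiguous
base.requires_grad_(True)     # base is a leaf, so the flag can be set in place
src = torch.randn(4, 4)

# Expected to trip the same INTERNAL ASSERT:
# "as_strided_scatter is currently only supported for contiguous inputs"
out = torch.as_strided_scatter(base, src, (4, 4), (1, 4))
```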
cc @zou3519 @bdhirsh @ezyang @msaroufim @wconstab @anijain2305 @soumith