[functorch] [aot_autograd]  #86427

@sanketpurandare

This is what I am trying to do:

  1. Create an input and target tensor
  2. Create a list of models for ensemble
  3. Use vmap to create an ensemble
  4. Pass the vmapped function to aot_function to obtain the forward and backward graphs

Repro:

from typing import List
import torch
from torch._subclasses import FakeTensorMode, FakeTensor
from functorch.compile import aot_function, print_compile, config, aot_module
from functorch import make_functional_with_buffers, vmap, combine_state_for_ensemble
from functorch._src.named_members_polyfill import _named_parameters, _named_buffers
from torchvision.models import resnet18

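g = {}  # captured FX graphs, keyed by "forward" / "backward"
def fake_wrapper(gtype):
    # Build a pass-through "compiler" that prints and records the traced graph.
    def fake_compiler(fx_g, inps):
        print(fx_g.code)
        g[gtype] = fx_g
        return fx_g
    return fake_compiler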


inp = torch.randn(32, 3, 224, 224, dtype=torch.float32).cuda()
targets = torch.zeros(32, dtype=int).cuda()

b_models:List[torch.nn.Module] = [resnet18().cuda() for _ in range(5)]

func_model, params, buffers = combine_state_for_ensemble(b_models)
for p in params:
    p.requires_grad = True

def compute_loss(weights, buffers, batch, targets):
    output = func_model(weights, buffers, batch)
    loss = torch.nn.functional.nll_loss(output, targets)
    return loss

# Map over the stacked weights and buffers; share the batch and targets across models.
parallel_func = vmap(compute_loss, in_dims=(0, 0, None, None))
aot_func = aot_function(parallel_func, fake_wrapper("forward"), fake_wrapper("backward"))
out = aot_func(params, buffers, inp, targets)
out.mean().backward()

Error:
RuntimeError: !self.requires_grad() || self.is_contiguous() INTERNAL ASSERT FAILED at "/scratch/anijain/work/pytorch/aten/src/ATen/native/TensorShape.cpp":3609, please report a bug to PyTorch. as_strided_scatter is currently only supported for contiguous inputs

While executing %new_empty_strided_1 : [#users=1] = call_function[target=torch.ops.aten.new_empty_strided.default](args = (%copy__1, [5, 512, 32, 7, 7], [25088, 49, 125440, 7, 1]), kwargs = {})
Original traceback:
Gradient addition node due to multiple use of tensor around:
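
As a side check, the shape and strides passed to new_empty_strided in the error can be compared against the row-major strides for that shape. The sketch below is plain Python, not PyTorch; `contiguous_strides` and `is_contiguous` are hypothetical helper names, and the check is simplified (it ignores the size-1 dimension special case that PyTorch applies). It only shows that the layout in the message is not contiguous and looks like a permuted view of a contiguous [32, 5, 512, 7, 7] tensor. It does not confirm where that layout comes from.

```python
from typing import List, Sequence


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    # Simplified check: exact match with the row-major strides.
    return list(strides) == contiguous_strides(shape)


# Values copied from the new_empty_strided call in the error message.
shape = [5, 512, 32, 7, 7]
strides = [25088, 49, 125440, 7, 1]

expected = contiguous_strides(shape)
print(expected)                       # [802816, 1568, 49, 7, 1]
print(is_contiguous(shape, strides))  # False

# Sorting dimensions by decreasing stride recovers the physical memory order.
order = sorted(range(len(shape)), key=lambda d: -strides[d])
physical_shape = [shape[d] for d in order]
physical_strides = [strides[d] for d in order]
print(physical_shape)                                  # [32, 5, 512, 7, 7]
print(is_contiguous(physical_shape, physical_strides)) # True
```

Under these assumptions, the size-32 dimension is outermost in memory, ahead of the size-5 ensemble dimension, so the tensor fails the is_contiguous() check in the assert.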

cc @zou3519 @bdhirsh @ezyang @msaroufim @wconstab @anijain2305 @soumith

Labels: module: aotdispatch, module: functionalization, module: pt2-dispatcher, module: vmap, oncall: pt2, triaged
