Commit fcaf8e3

Update on "[AOTI] Fix a fallback op returning None"

Summary: Fixes #135781. In some cases, a fallback op can return None in place of a tensor.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames chauhang

[ghstack-poisoned]

2 parents 043e644 + 3c4f67c
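
For a concrete sense of the failure mode (an illustration, not necessarily the op from #135781): some ATen ops declare tensor outputs that can come back as None, so fallback codegen cannot assume every output slot holds a real tensor. aten.convolution_backward with a partial output_mask is one such op:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
grad_out = torch.randn(1, 4, 6, 6)

# Only grad_input is requested; the weight and bias slots come back as
# None, which is exactly the kind of output a fallback must tolerate.
grads = torch.ops.aten.convolution_backward(
    grad_out, x, w,
    None,                  # bias_sizes
    [1, 1],                # stride
    [0, 0],                # padding
    [1, 1],                # dilation
    False,                 # transposed
    [0, 0],                # output_padding
    1,                     # groups
    [True, False, False],  # output_mask: grad_input only
)
print(grads)  # (Tensor, None, None)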

5 files changed (+54, -15)

test/inductor/test_aot_inductor.py (+35, -0)

@@ -3301,6 +3301,40 @@ def forward(self, x, y):
             Model(), example_inputs, options=dict(max_autotune=max_autotune)
         )
 
+    @skip_if_no_torchvision
+    def test_torchvision_transforms_functional_tensor_resize(self):
+        import torchvision
+
+        # https://fb.workplace.com/groups/1075192433118967/permalink/1501860707118802/
+        class A(torch.nn.Module):
+            def forward(self, image: torch.Tensor, target_size: torch.Tensor):
+                target_h, target_w = target_size.tolist()
+                torch._check(target_h > 0)
+                torch._check(target_w > 0)
+                torch._check(target_h <= 4000)
+                torch._check(target_w <= 4000)
+
+                return torchvision.transforms._functional_tensor.resize(
+                    image,
+                    size=[target_h, target_w],
+                    interpolation="bilinear",
+                    antialias=False,
+                )
+
+        model = A()
+        example_inputs = (
+            torch.ones([3, 800, 600], device=self.device),
+            torch.tensor([448, 336], device=self.device),
+        )
+        dynamic_shapes = {
+            "image": {
+                1: torch.export.Dim("height", min=1, max=4000),
+                2: torch.export.Dim("width", min=1, max=4000),
+            },
+            "target_size": None,
+        }
+        self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
+
     def test_aoti_debug_printer_codegen(self):
         # basic addmm model to test codegen for aoti intermediate debug printer
         class Model(torch.nn.Module):

@@ -3627,6 +3661,7 @@ def fail_non_abi_compatible_cuda(is_skip=False):
         is_skip=True
     ),
     "test_size_from_multi_output": fail_stack_allocation(is_skip=True),
+    "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(),
 }
 
 # test_failures, xfail by default, set is_skip=True to skip
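
The torch._check calls in this test are what make the data-dependent sizes workable: target_size.tolist() yields unbacked symbolic ints under export, and the checks give them the bounds the compiler needs to reason about the output shapes. A standalone sketch of the same pattern (hypothetical module, not part of this test file):

import torch

class MakeBuffer(torch.nn.Module):
    # hypothetical module illustrating the torch._check pattern above
    def forward(self, size: torch.Tensor):
        n = size.item()          # data-dependent: traces as an unbacked SymInt
        torch._check(n > 0)      # lower bound
        torch._check(n <= 4096)  # upper bound keeps the range finite
        return torch.zeros(n)

ep = torch.export.export(MakeBuffer(), (torch.tensor(16),))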

test/inductor/test_cpu_cpp_wrapper.py (+0, -1)

@@ -96,7 +96,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase):
             f"{test_name}_dynamic_shapes"
         ] = test_torchinductor.TestFailure(("cpp_wrapper",), is_skip=False)
 skip_list = [
-    "test_multihead_attention_cpu",
     *[
         func
         for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())

test/test_matmul_cuda.py (+1, -0)

@@ -560,6 +560,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
         self.assertEqual(out_fp8, out_fp8_s)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
+    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
     @skipIfRocm()
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
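
For context, the SM90OrLater flag used by the new decorator (defined in torch.testing._internal.common_cuda) boils down to a compute-capability check, roughly:

import torch

# Roughly what SM90OrLater evaluates: a Hopper-or-newer (sm_90+) GPU is present.
sm90_or_later = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (9, 0)
)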

torch/_inductor/compile_fx.py (+16, -13)

@@ -17,6 +17,7 @@
 import torch.fx
 import torch.utils._pytree as pytree
 from functorch.compile import min_cut_rematerialization_partition
+from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo import (
     compiled_autograd,
     config as dynamo_config,

@@ -400,20 +401,22 @@ def fake_tensor_prop(
 
     The created fake mode will be returned.
     """
-    fake_mode = detect_fake_mode(example_inputs)
-    if not fake_mode:
-        fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
-        FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
-    else:
-        ctx = (
-            contextlib.nullcontext()
-            if not force_allow_non_fake_inputs
-            else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
-        )
-        with ctx:  # type: ignore[attr-defined]
-            FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
-                *example_inputs
+    # Ensure that decomps that support symbolic shapes are used
+    with enable_python_dispatcher():
+        fake_mode = detect_fake_mode(example_inputs)
+        if not fake_mode:
+            fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
+            FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
+        else:
+            ctx = (
+                contextlib.nullcontext()
+                if not force_allow_non_fake_inputs
+                else mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
             )
+            with ctx:  # type: ignore[attr-defined]
+                FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(
+                    *example_inputs
+                )
 
     return fake_mode
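
The new enable_python_dispatcher() wrapper routes ops through the Python dispatcher while fake tensor propagation runs, so decompositions that understand symbolic shapes get picked up. A minimal sketch of the same pattern (not the compile_fx code itself):

import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensorMode

# Run shape propagation under the Python dispatcher so ops can hit
# Python-level, symbolic-shape-aware decompositions.
with enable_python_dispatcher():
    with FakeTensorMode():
        x = torch.empty(8, 8)            # a fake tensor: metadata only
        y = torch.nn.functional.gelu(x)  # propagates shapes, no real compute
print(y.shape)  # torch.Size([8, 8])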

torch/_subclasses/functional_tensor.py (+2, -1)

@@ -6,7 +6,6 @@
 from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
 
 import torch
-import torch._inductor.config as inductor_config
 import torch.utils._pytree as pytree
 from torch._C import _functionalization_reapply_views_tls as _reapply_views
 from torch._ops import _get_dispatch_mode_pre_dispatch

@@ -471,6 +470,8 @@ def unwrap(x):
         # it doesn't matter what mode we use here because
         # the implementation of do_auto_functionalize doesn't
         # interact with FunctionalTensorMode at all
+        import torch._inductor.config as inductor_config
+
         if self.export or not inductor_config.enable_auto_functionalized_v2:
             return do_auto_functionalize(func, args, kwargs)
         else:
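
Moving the inductor_config import from module scope into the branch that uses it is the standard deferred-import pattern: torch._subclasses stays importable without eagerly pulling in torch._inductor. A generic sketch of the idea (hypothetical modules, not PyTorch code):

# consumer.py -- hypothetical illustration of a deferred import
def dispatch(use_fast: bool) -> str:
    # Imported lazily: heavy_config is only loaded when dispatch() runs,
    # so importing this module never triggers a (possibly circular) import.
    import heavy_config
    return "fast" if use_fast and heavy_config.ENABLE_FAST else "slow"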
