
out= meta device support. #138396

@ysiraichi

Description

List of operations whose out= variants are inconsistent with eager, i.e. they run on CPU/CUDA but fail on the meta device. I have grouped them by the error each one raises.
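The consistency check comes down to running the same out= call twice, once on a regular device and once on meta, and checking whether exactly one of the two raised. The following is a minimal standalone sketch of that comparison, independent of the OpInfo harness. It assumes CPU as the eager device, and the helper names `run_with_out` and `check_out_consistency` are hypothetical, not PyTorch APIs.

```python
import torch
import torch.utils._pytree as pytree


def run_with_out(fn, args, out, device):
    """Call fn(*args, out=...) with every tensor placed on `device`.

    Returns the raised exception, or None if the call succeeded.
    """
    # Move the inputs to the target device.
    args = pytree.tree_map_only(torch.Tensor, lambda t: t.to(device), args)
    # Allocate fresh out= tensors (same shape/dtype) on the target device.
    out = pytree.tree_map_only(
        torch.Tensor, lambda t: torch.empty_like(t, device=device), out
    )
    try:
        fn(*args, out=out)
    except Exception as e:
        return e
    return None


def check_out_consistency(fn, args, out, eager_device="cpu"):
    """Return (consistent, eager_err, meta_err) for one out= call."""
    eager_err = run_with_out(fn, args, out, eager_device)
    meta_err = run_with_out(fn, args, out, "meta")
    # Consistent means both calls succeeded or both raised.
    consistent = (eager_err is None) == (meta_err is None)
    return consistent, eager_err, meta_err


if __name__ == "__main__":
    x, y = torch.randn(2, 3), torch.randn(2, 3)
    print(check_out_consistency(torch.add, (x, y), torch.empty(2, 3)))
```

An op from the lists below would show up as `consistent == False` with `eager_err is None`, matching the "eager didn't fail, but meta did." error in the test.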

No Meta Kernel Registered

  • _native_batch_norm_legit
  • geqrf
Error Example
Traceback (most recent call last):
  File "examples/ops.py", line 88, in run
    f(input_, *args_, **kwargs_, out=out)
  File "torch/testing/_internal/opinfo/core.py", line 1169, in __call__
    return self.op(*args, **kwargs)
NotImplementedError: aten::_native_batch_norm_legit.out: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl. Please see the following for next steps:  https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "torch/testing/_internal/common_device_type.py", line 1140, in test_wrapper
    return test(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 1371, in only_fn
    return fn(slf, *args, **kwargs)
  File "examples/ops.py", line 96, in test_meta_out
    raise RuntimeError(f"eager didn't fail, but meta did.") from meta_err
RuntimeError: eager didn't fail, but meta did.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/lib/python3.9/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/lib/python3.9/unittest/case.py", line 592, in run
    self._callTestMethod(testMethod)
  File "/lib/python3.9/unittest/case.py", line 550, in _callTestMethod
    method()
  File "torch/testing/_internal/common_utils.py", line 2983, in wrapper
    method(*args, **kwargs)
  File "torch/testing/_internal/common_utils.py", line 2983, in wrapper
    method(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 448, in instantiated_test
    result = test(self, **param_kwargs)
  File "torch/testing/_internal/common_utils.py", line 1530, in wrapper
    fn(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 1152, in test_wrapper
    raise e_tracked from e
Exception: Caused by sample input at index 9: SampleInput(input=Tensor[size=(1, 2, 3), device="cuda:0", dtype=torch.float32], args=(None,None,True,0.5,1e-05), kwargs={}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=9 python ops.py TestCommonCUDA.test_meta_out__native_batch_norm_legit_cuda_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

Other Operations

  • nanmean
Error Traceback
Traceback (most recent call last):
  File "examples/ops.py", line 88, in run
    f(input_, *args_, **kwargs_, out=out)
  File "torch/testing/_internal/opinfo/core.py", line 1169, in __call__
    return self.op(*args, **kwargs)
RuntimeError: DispatchStub: unsupported device typemeta

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "torch/testing/_internal/common_device_type.py", line 1140, in test_wrapper
    return test(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 1371, in only_fn
    return fn(slf, *args, **kwargs)
  File "examples/ops.py", line 96, in test_meta_out
    raise RuntimeError(f"eager didn't fail, but meta did.") from meta_err
RuntimeError: eager didn't fail, but meta did.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/lib/python3.9/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/lib/python3.9/unittest/case.py", line 592, in run
    self._callTestMethod(testMethod)
  File "/lib/python3.9/unittest/case.py", line 550, in _callTestMethod
    method()
  File "torch/testing/_internal/common_utils.py", line 2983, in wrapper
    method(*args, **kwargs)
  File "torch/testing/_internal/common_utils.py", line 2983, in wrapper
    method(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 448, in instantiated_test
    result = test(self, **param_kwargs)
  File "torch/testing/_internal/common_utils.py", line 1530, in wrapper
    fn(*args, **kwargs)
  File "torch/testing/_internal/common_device_type.py", line 1152, in test_wrapper
    raise e_tracked from e
Exception: Caused by sample input at index 34: SampleInput(input=Tensor[size=(2, 2), device="cuda:0", dtype=torch.float32], args=(), kwargs={'dim': '(0,-1)', 'keepdim': 'True'}, broadcasts_input=False, name='')

To execute this test, run the following from the base repo dir:
    PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=34 python ops.py TestCommonCUDA.test_meta_out_nanmean_cuda_float32

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
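Unlike the dynamic-shape operations in the next group, the output shape of nanmean depends only on the input shape, `dim` and `keepdim`, not on the data. The sketch below shows that shape computation, which is what a meta implementation would have to produce. `reduction_output_shape` is a hypothetical helper, not a PyTorch API. It ignores the legacy `dim=()` special case.

```python
import torch


def reduction_output_shape(shape, dim=None, keepdim=False):
    """Output shape of a dim-reduction such as nanmean.

    Needs only the input shape, `dim` and `keepdim`, never the data.
    """
    ndim = len(shape)
    wrap = max(ndim, 1)  # 0-d tensors still accept dim=0 / dim=-1
    if dim is None:
        dims = set(range(ndim))  # reduce over every dimension
    else:
        dims = {d % wrap for d in ((dim,) if isinstance(dim, int) else dim)}
    if keepdim:
        # Reduced dimensions are kept with size 1.
        return tuple(1 if i in dims else s for i, s in enumerate(shape))
    # Reduced dimensions are dropped.
    return tuple(s for i, s in enumerate(shape) if i not in dims)


if __name__ == "__main__":
    # Same shape and kwargs as sample input 34: (2, 2), dim=(0, -1), keepdim=True.
    print(reduction_output_shape((2, 2), (0, -1), True))  # (1, 1)
    x = torch.randn(2, 2)
    print(torch.nanmean(x, dim=(0, -1), keepdim=True).shape)  # torch.Size([1, 1])
```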

Dynamic Shape Output

The operations listed below return tensors whose shapes depend on the input data (dynamic shapes). As a result, their output shape cannot be known without the actual data, so a meta function cannot be implemented for them.

  • linalg_lstsq
  • masked_select
  • nonzero
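The following shows the data dependence for two of the operations listed above. The inputs have the same shape, but different contents produce different output shapes. A meta tensor, which carries no data, cannot supply the information these shapes depend on.

```python
import torch

# Two inputs with the same shape (2, 2) but different contents.
a = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
b = torch.tensor([[1.0, 2.0], [3.0, 0.0]])

# nonzero returns one row of indices per non-zero element, so its output
# shape is (number of non-zeros, ndim): known only after reading the data.
print(torch.nonzero(a).shape)  # torch.Size([1, 2])
print(torch.nonzero(b).shape)  # torch.Size([3, 2])

# masked_select returns a 1-D tensor with one element per True in the mask.
print(torch.masked_select(b, b > 0).shape)  # torch.Size([3])
print(torch.masked_select(b, b > 2).shape)  # torch.Size([1])
```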

Test Setup

To reproduce these results, in addition to the test below, we had to turn the wrapper_set_seed function into a no-op wrapper that simply calls the op:

--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -40,8 +40,9 @@ from torch.testing._internal.common_utils import (
     GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW,
     TEST_WITH_TORCHINDUCTOR
 )
-from torch.testing._utils import wrapper_set_seed
+# from torch.testing._utils import wrapper_set_seed
 
+import torch
 import torch._refs as refs  # noqa: F401
 import torch._refs.nn.functional
 import torch._refs.special
@@ -50,6 +51,9 @@ import torch._prims as prims  # noqa: F401
 from torch.utils import _pytree as pytree
 
 
+def wrapper_set_seed(op, *args, **kwargs):
+    return op(*args, **kwargs)
+
 from packaging import version
 
 from torch.testing._internal.opinfo.core import (  # noqa: F401
OpInfo Test
import torch
import torch.utils._pytree as pytree
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests, OpDTypes, onlyCUDA, onlyCPU
from torch.testing._internal.common_utils import TestCase, run_tests

class TestCommon(TestCase):

    @ops([op for op in op_db if op.supports_out], allowed_dtypes=(torch.float32,))
    def test_meta_out(self, device, dtype, op):
        samples = list(op.sample_inputs(device, dtype))

        for i, sample in enumerate(samples):
            torch._dynamo.reset()
            input, args, kwargs = (sample.input, sample.args, sample.kwargs)

            # Run the functional version of the operation, using eager.
            try:
                expected = op(input, *args, **kwargs)

                # Normalize named-tuple results (e.g. torch.return_types) into plain tuples.
                if isinstance(expected, tuple):
                    expected = tuple(expected)
            except Exception:
                # If the functional version fails, skip to the next sample.
                continue

            def run(f, dev):
                # Create new outputs in the desired device.
                out = pytree.tree_map_only(torch.Tensor, lambda t: torch.empty_like(t, device=dev), expected)

                # Move inputs to the desired device
                stuff = (input, args, kwargs)
                stuff = pytree.tree_map_only(torch.Tensor, lambda t: t.to(dev), stuff)
                stuff = pytree.tree_map_only(torch.device, lambda d: torch.device(dev), stuff)
                stuff = pytree.tree_map_only(str, lambda v: dev if v == device else v, stuff)
                input_, args_, kwargs_ = stuff

                # Try running the operation, and return the raised error, if any.
                try:
                    f(input_, *args_, **kwargs_, out=out)
                except Exception as e:
                    return e

            eager_err = run(op, device)
            meta_err = run(op, "meta")

            if eager_err is None and meta_err is not None:
                raise RuntimeError(f"eager didn't fail, but meta did.") from meta_err
            elif eager_err is not None and meta_err is None:
                raise RuntimeError(f"eager failed, but meta didn't.") from eager_err

instantiate_device_type_tests(TestCommon, globals())

if __name__ == "__main__":
    run_tests()

Versions

PyTorch version: 2.5.0a0+git7128504
Is debug build: True
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

cc @ezyang @eellison @bdhirsh

Labels: module: meta tensors, triaged
