Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions test/prototype_transforms_dispatcher_infos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from typing import Callable, Dict, Sequence, Type
from collections import defaultdict
from typing import Callable, Dict, List, Sequence, Type

import pytest
import torchvision.prototype.transforms.functional as F
Expand All @@ -11,15 +12,30 @@
KERNEL_SAMPLE_INPUTS_FN_MAP = {info.kernel: info.sample_inputs_fn for info in KERNEL_INFOS}


def skip_python_scalar_arg_jit(name, *, reason="Python scalar int or float is not supported when scripting"):
return Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs[name], (int, float)),
reason=reason,
)


def skip_integer_size_jit(name="size"):
return skip_python_scalar_arg_jit(name, reason="Integer size is not supported when scripting.")


@dataclasses.dataclass
class DispatcherInfo:
dispatcher: Callable
kernels: Dict[Type, Callable]
skips: Sequence[Skip] = dataclasses.field(default_factory=list)
_skips_map: Dict[str, Skip] = dataclasses.field(default=None, init=False)
_skips_map: Dict[str, List[Skip]] = dataclasses.field(default=None, init=False)

def __post_init__(self):
self._skips_map = {skip.test_name: skip for skip in self.skips}
skips_map = defaultdict(list)
for skip in self.skips:
skips_map[skip.test_name].append(skip)
self._skips_map = dict(skips_map)

def sample_inputs(self, *types):
for type in types or self.kernels.keys():
Expand All @@ -29,9 +45,13 @@ def sample_inputs(self, *types):
yield from KERNEL_SAMPLE_INPUTS_FN_MAP[self.kernels[type]]()

def maybe_skip(self, *, test_name, args_kwargs, device):
skip = self._skips_map.get(test_name)
if skip and skip.condition(args_kwargs, device):
pytest.skip(skip.reason)
skips = self._skips_map.get(test_name)
if not skips:
return

for skip in skips:
if skip.condition(args_kwargs, device):
pytest.skip(skip.reason)


DISPATCHER_INFOS = [
Expand All @@ -50,6 +70,9 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.resize_bounding_box,
features.Mask: F.resize_mask,
},
skips=[
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These skips currently don't do anything since we don't have integer sizes, but will be necessary as soon that happens.

skip_integer_size_jit(),
],
),
DispatcherInfo(
F.affine,
Expand All @@ -58,6 +81,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.affine_bounding_box,
features.Mask: F.affine_mask,
},
skips=[skip_python_scalar_arg_jit("shear", reason="Scalar shear is not supported by JIT")],
),
DispatcherInfo(
F.vertical_flip,
Expand Down Expand Up @@ -122,12 +146,19 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.BoundingBox: F.center_crop_bounding_box,
features.Mask: F.center_crop_mask,
},
skips=[
skip_integer_size_jit("output_size"),
],
),
DispatcherInfo(
F.gaussian_blur,
kernels={
features.Image: F.gaussian_blur_image_tensor,
},
skips=[
skip_python_scalar_arg_jit("kernel_size"),
skip_python_scalar_arg_jit("sigma"),
],
),
DispatcherInfo(
F.equalize,
Expand Down Expand Up @@ -207,11 +238,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.Image: F.five_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting five_crop_image_tensor.",
),
skip_integer_size_jit(),
],
),
DispatcherInfo(
Expand All @@ -220,11 +247,7 @@ def maybe_skip(self, *, test_name, args_kwargs, device):
features.Image: F.ten_crop_image_tensor,
},
skips=[
Skip(
"test_scripted_smoke",
condition=lambda args_kwargs, device: isinstance(args_kwargs.kwargs["size"], int),
reason="Integer size is not supported when scripting ten_crop_image_tensor.",
),
skip_integer_size_jit(),
],
),
DispatcherInfo(
Expand Down
Loading