
Conversation

@jbschlosser
Contributor

@jbschlosser jbschlosser commented Nov 12, 2024

Stack from ghstack (oldest at bottom):

Background

This PR adds the functionality to xfail / skip on a per-SampleInput basis for OpInfo tests. See #89354 and #82669 for some requests asking for this type of functionality.

This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

Design

Principles

  • Clean separation among SampleInput generation logic, test logic that uses the SampleInputs, and xfail / skip logic (which will change as bugs are addressed).
  • Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    • This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
  • Opt-in with minimal test logic changes + no substantial impact on other tests.

Details

The core new concept is a SampleRule, which can be either an XFailRule or SkipRule.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```
  • See below for example usage details, but at a high level: each test should have a corresponding list of sample_skips_and_xfails.
    • The list of sample_skips_and_xfails is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
    • The PR includes a logging mechanism for matched rules accessible by setting the loglevel to DEBUG.
  • The split between op_match_fn and sample_match_fn is made to allow pre-filtering of the list of rules to get only those that apply to the op under test (see the sketch after this list).
  • Each SampleInput is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring SampleInput; all samples will be run through test logic.
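
To make the matching flow concrete, here is a minimal sketch of the two-stage matching described above. The function names are hypothetical (not the actual torch.testing internals); only the semantics (op-level pre-filtering, then first-match-wins per sample) follow the bullets above.

```python
# Hypothetical helpers illustrating the matching semantics; the names are made
# up for this sketch and are not the actual torch.testing internals.
def filter_rules_for_op(rules, device_type, op):
    # op_match_fn pre-filters the full rule list down to the rules relevant to this op
    return [r for r in rules if r.op_match_fn is not None and r.op_match_fn(device_type, op)]


def find_matching_rule(rules, device, sample):
    # rules are checked in order; the first sample-level match wins
    for rule in rules:
        if rule.sample_match_fn is not None and rule.sample_match_fn(device, sample):
            return rule
    return None  # no rule matched -> run this sample normally
```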

Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of SampleInputs and run them through the op. Now say you want to xfail one of these SampleInputs for a given op. Today, you have to xfail the entire test or hack around this in the test logic.

This PR lets you define flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but
# it can be more readable to maintain these somewhere else. These are attempted to be matched in order and
# the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously
    # this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]
```

```python
class MyTestCase(TestCase):
    @ops(op_db)
    @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
    # NB: the @ops decorator automatically filters out any rules that don't apply to this op
    def test_foo(self, device, dtype, op):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True,
            # an informative error will be thrown.
            device, dtype, requires_grad=False, use_subtests=True
        ):
            # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```
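
For intuition, the subtest context used above could behave roughly like the following sketch, building on the SkipRule / XFailRule dataclasses shown earlier. This is a hedged illustration, not the actual helper returned alongside each sample: a matched SkipRule raises SkipTest, while a matched XFailRule asserts that the expected error is raised by the test body.

```python
import contextlib
import unittest

# Hypothetical sketch of a per-sample skip / xfail context; example_subtest_ctx
# and its arguments are illustrative, not the real API.
@contextlib.contextmanager
def example_subtest_ctx(test_case, matched_rule):
    # matched_rule is the first matching SampleRule for this sample, or None
    if isinstance(matched_rule, SkipRule):
        # skip this sample entirely
        raise unittest.SkipTest(f"skipped by rule: {matched_rule.name}")
    elif isinstance(matched_rule, XFailRule):
        # the body must raise the expected error; a pass or a different error fails the test
        with test_case.assertRaisesRegex(matched_rule.error_type, matched_rule.error_msg):
            yield
    else:
        # no rule matched: run the test body as usual
        yield
```

In the real system each sample additionally runs as a subtest, so later samples still execute even if an earlier one fails.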

More examples can be seen in test/test_nestedtensor.py, where this system is used in practice.

I also demonstrate usage of syntactic sugar over this system in test/functorch/test_vmap.py. Here, a skip for the to() operator is replaced with a granular xfail for test_vmap_exhaustive():

```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
    "to",
    lambda sample: (
        sample.kwargs["memory_format"] == torch.channels_last
    ),
),
...
```
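
For reference, xfailIf-style sugar could be a thin wrapper over XFailRule along these lines. This is a hedged sketch; the actual helper used in test_vmap.py may differ in signature and defaults.

```python
# Hypothetical sugar over the rule system above; argument names and defaults
# are illustrative, not the actual helper's signature.
def xfailIf(op_name, sample_predicate, *, error_type=Exception, error_msg=".*"):
    return XFailRule(
        error_type=error_type,
        error_msg=error_msg,
        # match by op name; a real helper might also account for variant names
        op_match_fn=lambda device, op: op.full_name == op_name,
        # adapt the single-argument predicate to the (device, sample) signature
        sample_match_fn=lambda device, sample: sample_predicate(sample),
        name=f"xfailIf_{op_name}",
    )
```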

@jbschlosser jbschlosser requested review from a team and mruberry as code owners November 12, 2024 20:06
@pytorch-bot
pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140443

✅ You can merge normally! (2 Unrelated Failures)

As of commit ee8d0d8 with merge base e6c5a77:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser jbschlosser added the module: testing and topic: not user facing labels Nov 12, 2024
@jbschlosser jbschlosser requested review from janeyx99 and removed request for mruberry November 12, 2024 20:09


```python
# setup logging
log = logging.getLogger(__name__)
```
Contributor

Oh do we have logging as our normal print infra for tests now?

Contributor Author

ah, I'm probably doing this wrong then. What's the right way to do an optional debug-only log entry?

Contributor

This I have no idea about... maybe this is the way? cc @zou3519
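
For context, an optional debug-only entry with the stdlib logging module typically looks like the sketch below (a generic Python pattern, not a statement about what PyTorch's test infra prefers; the function and its arguments are illustrative):

```python
import logging

# module-level logger, as in the snippet under review
log = logging.getLogger(__name__)

def report_match(rule_name, sample_repr):
    # emitted only when the effective log level is DEBUG
    log.debug("matched rule %s for sample %s", rule_name, sample_repr)

# e.g. enable for a local run (test runners may configure logging differently):
# logging.basicConfig(level=logging.DEBUG)
```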

@jbschlosser jbschlosser added the ciflow/trunk label Nov 13, 2024
Contributor

@janeyx99 janeyx99 left a comment

Approving as this PR is almost good (we've agreed on naming + concepts) and because I'll be out the next 3 weeks. However, please do rename + document the sample rules more clearly, so that people can quickly pick up what's going on when they're looking for it.

@jbschlosser jbschlosser requested a review from zou3519 November 14, 2024 14:58
```python
@ops(
    [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
    allowed_dtypes=(torch.float32,),
    sample_skips_and_xfails=BACKWARD_SKIPS_AND_XFAILS,
```
Contributor

General feedback: I can see this being useful. It's not directly useful right now for my use cases (un-skipping some functorch tests). That being said, it's probably not hard to build this into that.

Some questions I have: skips and xfails are generally on the OpInfo objects themselves. Are you thinking of having a variant that can look like that? It would be cool to have a single decorator you can apply to a function for an OpInfo that makes one of these rules.

Contributor Author

I updated the PR to store skips / xfails on the test function via the @sample_skips_and_xfails decorator. @ops reads and filters these, placing skips / xfails relevant to each op onto the parametrized test functions. Then the sample input generation funcs (e.g. op.sample_inputs()) pull these off of the test function to establish the right subtest context.

There are minimal changes to the test logic needed now, but still some: it's necessary to pass op.sample_inputs(..., use_subtests=True) when skips / xfails are defined so they can be handled appropriately. Then the iterator returns (sample, subtest_ctx) and the subtest_ctx must be entered.

> Some questions I have are: skips and xfails are generally on the OpInfo objects themselves. Are you thinking of having a variant that can look like that? It would be cool to have a single decorator you can apply to a Function for an OpInfo that makes one of these rules

With the above changes, it's possible to define an xfailIf decorator that is just syntactic sugar over the rule system. I added an example using this for un-skipping a vmap test with a more granular xfail based on a SampleInput property.
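
To make the mechanics above concrete, the decorator could be roughly the following (a hedged sketch; the attribute name used to stash the rules is made up for illustration, not the actual implementation):

```python
# Hypothetical sketch of @sample_skips_and_xfails; not the actual implementation.
def sample_skips_and_xfails(rules):
    def decorator(test_fn):
        # stash the full rule list on the test function; @ops later reads this
        # attribute, keeps only rules whose op_match_fn matches each parametrized
        # op, and op.sample_inputs(..., use_subtests=True) consumes the result
        test_fn._sample_skips_and_xfails = list(rules)
        return test_fn
    return decorator
```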

Comment on lines +4383 to +4388
```python
xfailIf(
    "to",
    lambda sample: (
        sample.kwargs["memory_format"] == torch.channels_last
    ),
),
```
Contributor

this is awesome

Contributor

@zou3519 zou3519 left a comment

looks awesome

@jbschlosser
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

This was referenced Nov 21, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: pytorch#140160, pytorch#138370
Esquains pushed a commit to Esquains/study1 that referenced this pull request Dec 15, 2024
@github-actions github-actions bot deleted the gh/jbschlosser/199/head branch December 20, 2024 02:05