# General per-SampleInput xfail / skip system #140443
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140443

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.

✅ You can merge normally! (2 Unrelated Failures) As of commit ee8d0d8 with merge base e6c5a77:

* BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. The key goal of this PR is to maintain clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

How does it work? Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line
# in the ops decorator, but it can be more readable to maintain these somewhere else.
# These are attempted to be matched in order and the first one that matches applies,
# so order can matter.
FOO_SAMPLE_RULES = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    # NB: the ops decorator automatically filters out any sample_rules that don't apply to this op
    @ops(op_db, sample_rules=FOO_SAMPLE_RULES)
    def test_foo(self, device, dtype, op, sample_rules):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: passing sample_rules here enables the opt-in functionality to get subtest xfails / skips
            device, dtype, requires_grad=False, sample_rules=sample_rules
        ):
            # NB: this subtest context manager runs each sample input as a "subtest"
            # and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

"Rules" above can only be xfails or skips. More examples can be seen in `test/test_nestedtensor.py`, where this is used in practice. There is also some logging of matched rules for debugging purposes, accessible by setting the loglevel to `DEBUG`.

[ghstack-poisoned]
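The first-match semantics described above (rules are attempted in order and the first match applies) can be sketched roughly as follows. This is a simplified stand-in, not the PR's actual classes: `Rule`, `find_matching_rule`, and the dummy op / sample values are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence

@dataclass
class Rule:
    # Predicates mirror the op_match_fn / sample_match_fn split from the PR.
    op_match_fn: Callable[[str, object], bool]
    sample_match_fn: Callable[[str, object], bool]
    name: str = ""

def find_matching_rule(
    rules: Sequence[Rule], device: str, op: object, sample: object
) -> Optional[Rule]:
    # Rules are tried in order; the first one that matches applies,
    # so ordering within the rule list can matter.
    for rule in rules:
        if rule.op_match_fn(device, op) and rule.sample_match_fn(device, sample):
            return rule
    return None
```

A catch-all rule placed last acts as a fallback, while more specific rules placed earlier take precedence.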
```python
# setup logging
log = logging.getLogger(__name__)
```
Oh do we have logging as our normal print infra for tests now?
ah I'm probably doing this wrong then. What's the right way to do an optional debug-only log entry?
This I have no idea... maybe this is the way? cc @zou3519
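The debug-only logging pattern under discussion can be sketched with the stdlib `logging` module. This is a minimal, hypothetical example (the helper name `report_matched_rule` is made up, not from the PR):

```python
import logging

# Module-level logger, as in the PR's "# setup logging" snippet. Debug
# messages are suppressed unless the effective level is DEBUG (e.g. after
# logging.basicConfig(level=logging.DEBUG)).
log = logging.getLogger(__name__)

def report_matched_rule(rule_name: str) -> None:
    # Emitted only when the loglevel is set to DEBUG; invisible otherwise.
    log.debug("matched rule: %s", rule_name)
```

By default these calls cost almost nothing; running the test suite with the loglevel set to `DEBUG` makes the entries visible.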
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. (Note: this term might be too general, making this more confusing than it needs to be; please suggest alternatives.)

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_rules` that specify xfails / skips.
* The list of `sample_rules` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line
# in the ops decorator, but it can be more readable to maintain these somewhere else.
# These are attempted to be matched in order and the first one that matches applies,
# so order can matter.
FOO_SAMPLE_RULES = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    # NB: the ops decorator automatically filters out any sample_rules that don't apply to this op
    @ops(op_db, sample_rules=FOO_SAMPLE_RULES)
    def test_foo(self, device, dtype, op, sample_rules):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: passing sample_rules here enables the opt-in functionality to get subtest xfails / skips
            device, dtype, requires_grad=False, sample_rules=sample_rules
        ):
            # NB: this subtest context manager runs each sample input as a "subtest"
            # and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.

[ghstack-poisoned]
janeyx99
left a comment
Approving as this PR is almost good (we're agreed on naming + concepts) and because I'll be out the next 3 weeks. However, please do rename + document the sample rules more clearly, so that people can quickly pick up what's going on when they're looking for it.
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_rules` that specify xfails / skips.
* The list of `sample_rules` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line
# in the ops decorator, but it can be more readable to maintain these somewhere else.
# These are attempted to be matched in order and the first one that matches applies,
# so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    # NB: the ops decorator automatically filters out any rules that don't apply to this op
    @ops(op_db, sample_skips_and_xfails=FOO_SKIPS_AND_XFAILS)
    def test_foo(self, device, dtype, op, sample_skips_and_xfails):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: passing sample_skips_and_xfails here enables the opt-in functionality
            # to get subtest xfails / skips
            device, dtype, requires_grad=False, sample_skips_and_xfails=sample_skips_and_xfails
        ):
            # NB: this subtest context manager runs each sample input as a "subtest"
            # and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.

[ghstack-poisoned]
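The pre-filtering enabled by the `op_match_fn` / `sample_match_fn` split can be sketched as below. This is a hypothetical simplification, not the decorator's real code: `FakeRule` and `rules_for_op` are invented names, and the op is represented by a bare string.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence

@dataclass
class FakeRule:
    # Hypothetical stand-in carrying only the op predicate from the design.
    op_match_fn: Callable[[str, str], bool]
    name: str = ""

def rules_for_op(rules: Sequence[FakeRule], device_type: str, op_name: str) -> List[FakeRule]:
    # Pre-filter once per op: only rules whose op predicate matches survive,
    # so only the per-sample predicate remains to be checked inside the test.
    return [r for r in rules if r.op_match_fn(device_type, op_name)]
```

This is why the two predicates are separate fields: the op-level check runs once at parametrization time, while the sample-level check runs per `SampleInput`.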
test/test_nestedtensor.py (outdated)

```python
@ops(
    [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
    allowed_dtypes=(torch.float32,),
    sample_skips_and_xfails=BACKWARD_SKIPS_AND_XFAILS,
)
```
General feedback: I can see this being useful. It's not directly useful right now for my use cases (un-skipping some functorch tests). That being said, it's probably not hard to build this into that.
Some questions I have are: skips and xfails are generally on the OpInfo objects themselves. Are you thinking of having a variant that can look like that? It would be cool to have a single decorator you can apply to a Function for an OpInfo that makes one of these rules
I updated the PR to store skips / xfails on the test function via the `@sample_skips_and_xfails` decorator. `@ops` reads and filters these, placing skips / xfails relevant to each op onto the parametrized test functions. Then the sample input generation funcs (e.g. `op.sample_inputs()`) pull these off of the test function to establish the right subtest context.

There are minimal changes to the test logic needed now, but still some: it's necessary to pass `op.sample_inputs(..., use_subtests=True)` when skips / xfails are defined so they can be handled appropriately. Then the iterator returns `(sample, subtest_ctx)` and the `subtest_ctx` must be entered.
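One way such a subtest context could behave is sketched below. This is a simplified, hypothetical illustration, not the PR's implementation (which builds on the test framework's subtest and skip/xfail machinery); `xfail_ctx` is an invented name.

```python
import contextlib
import re

@contextlib.contextmanager
def xfail_ctx(error_type=Exception, error_msg=".*"):
    # Sketch of xfail semantics for one SampleInput: the expected error is
    # swallowed (the subtest passes), while an unexpected success is reported
    # as a failure so stale rules get noticed and removed.
    try:
        yield
    except error_type as e:
        if re.search(error_msg, str(e)):
            return  # expected failure; treat subtest as passing
        raise  # wrong message: a genuine failure
    else:
        raise AssertionError("unexpected success; remove the xfail rule")
```

Because each sample enters its own context, one erroring `SampleInput` no longer aborts the whole test; the remaining samples still run.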
> Some questions I have are: skips and xfails are generally on the OpInfo objects themselves. Are you thinking of having a variant that can look like that? It would be cool to have a single decorator you can apply to a Function for an OpInfo that makes one of these rules
With the above changes, it's possible to define an `xfailIf` decorator that is just syntactic sugar over the rule system. I added an example using this for un-skipping a vmap test with a more granular xfail based on a `SampleInput` property.
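Such sugar could be little more than a rule factory, roughly as sketched here. The names (`XFailRuleSketch`, `xfail_if`) and the dict-based sample stand-in are hypothetical; the PR's real `xfailIf` signature may differ.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class XFailRuleSketch:
    # Hypothetical stand-in for the PR's XFailRule.
    op_match_fn: Callable[[str, Any], bool]
    sample_match_fn: Callable[[str, Any], bool]
    name: str = ""

def xfail_if(op_name: str, sample_pred: Callable[[Any], bool]) -> XFailRuleSketch:
    # Sugar: pin the rule to a single op by name and wrap the per-sample predicate.
    return XFailRuleSketch(
        op_match_fn=lambda device, op: op == op_name,
        sample_match_fn=lambda device, sample: sample_pred(sample),
        name=f"xfail_{op_name}",
    )
```

A call like `xfail_if("to", lambda sample: ...)` then expands to an ordinary rule that the existing matching machinery handles.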
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line
# in the ops decorator, but it can be more readable to maintain these somewhere else.
# These are attempted to be matched in order and the first one that matches applies,
# so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    # NB: the ops decorator automatically filters out any rules that don't apply to this op
    @ops(op_db, sample_skips_and_xfails=FOO_SKIPS_AND_XFAILS)
    def test_foo(self, device, dtype, op, sample_skips_and_xfails):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: passing sample_skips_and_xfails here enables the opt-in functionality
            # to get subtest xfails / skips
            device, dtype, requires_grad=False, sample_skips_and_xfails=sample_skips_and_xfails
        ):
            # NB: this subtest context manager runs each sample input as a "subtest"
            # and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.

[ghstack-poisoned]
```python
xfailIf(
    "to",
    lambda sample: (
        sample.kwargs["memory_format"] == torch.channels_last
    ),
),
```
this is awesome
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests; simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line
# in the ops decorator, but it can be more readable to maintain these somewhere else.
# These are attempted to be matched in order and the first one that matches applies,
# so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]

class MyTestCase(TestCase):
    # NB: the ops decorator automatically filters out any rules that don't apply to this op
    @ops(op_db)
    @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
    def test_foo(self, device, dtype, op):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails
            # are defined and use_subtests != True, an informative error will be thrown.
            device, dtype, requires_grad=False, use_subtests=True
        ):
            # NB: this subtest context manager runs each sample input as a "subtest"
            # and handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.

[ghstack-poisoned]
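The `use_subtests=True` contract can be sketched as a generator, shown below. This is a hypothetical simplification (sample generation and rule matching are stubbed out; `contextlib.nullcontext` stands in for the real subtest context), not the PR's actual `sample_inputs` code.

```python
import contextlib

def sample_inputs(samples, rules=(), use_subtests=False):
    # Sketch of the opt-in check: defining skip/xfail rules without
    # use_subtests=True raises an informative error instead of silently
    # ignoring the rules.
    if rules and not use_subtests:
        raise ValueError(
            "sample skips/xfails are defined but use_subtests=True was not passed"
        )
    for sample in samples:
        if use_subtests:
            # A real implementation would return a context applying the first
            # matching rule; nullcontext stands in for "no rule matched".
            yield sample, lambda test_case: contextlib.nullcontext()
        else:
            yield sample
```

With `use_subtests=True` the iterator yields `(sample, subtest_ctx)` pairs and the caller must enter each context; without it, plain samples are yielded and legacy tests keep working unchanged.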
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or a `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""


@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"


@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests: simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you define rules to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the
# @ops decorator, but it can be more readable to maintain these somewhere else. These are
# attempted to be matched in order and the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]


class MyTestCase(TestCase):
    @ops(op_db)
    @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
    # NB: the @ops decorator automatically filters out any rules that don't apply to this op
    def test_foo(self, device, dtype, op):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails
            # are defined and use_subtests != True, an informative error will be thrown.
            device, dtype, requires_grad=False, use_subtests=True
        ):
            # NB: this subtest context manager runs each sample input as a "subtest" and
            # handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.
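The order-sensitive, first-match rule resolution described above can be sketched as follows. This is a minimal standalone illustration, not the PR's actual implementation; the helper names `filter_rules` and `match_rule` and the simplified rule classes are hypothetical:

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

# Hypothetical, simplified mirrors of the rule types described in the PR.
@dataclass
class SampleRule:
    op_match_fn: Callable = None
    sample_match_fn: Callable = None
    name: str = ""

@dataclass
class XFailRule(SampleRule):
    error_type: type = Exception
    error_msg: str = ".*"

@dataclass
class SkipRule(SampleRule):
    pass

def filter_rules(rules: List[SampleRule], device_type, op) -> List[SampleRule]:
    # Pre-filtering: keep only rules whose op_match_fn accepts this op.
    # This is what the op_match_fn / sample_match_fn split enables.
    return [r for r in rules if r.op_match_fn(device_type, op)]

def match_rule(filtered_rules: List[SampleRule], device, sample) -> Optional[SampleRule]:
    # Traverse in order; the first rule that matches the sample wins,
    # which is why rule order can matter.
    for rule in filtered_rules:
        if rule.sample_match_fn(device, sample):
            return rule
    return None
```

With this shape, the `@ops` decorator can call `filter_rules` once per op, and the per-sample subtest context only needs to run `match_rule` against the already-filtered list.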
zou3519
left a comment
looks awesome
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
### Background

This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See pytorch#89354 and pytorch#82669 for some requests asking for this type of functionality. This was originally landed for NJT in pytorch#138370 and is generalized and slightly tweaked here.

### Design

#### Principles

* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
    * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.

#### Details

The core new concept is a `SampleRule`, which can be either an `XFailRule` or a `SkipRule`.

```python
@dataclass
class SampleRule(ABC):
    # function to indicate whether the rule applies to this op; return True if so
    # NB: str arg of callable is device_type
    op_match_fn: Callable[[str, OpInfo], bool] = None
    # function to indicate whether the rule applies to this sample; return True if so
    sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
    # optional name for identifying the rule
    name: str = ""


@dataclass
class XFailRule(SampleRule):
    # expected error type
    error_type: TypeVar = Exception
    # expected error message
    error_msg: str = ".*"


@dataclass
class SkipRule(SampleRule):
    ...
```

* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules, accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.

### Example Usage

Consider the following OpInfo test:

```python
class MyTestCase(TestCase):
    @ops(op_db)
    def test_foo(self, device, dtype, op):
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            # do some SampleInput-based test logic
            output = op.op(sample.input, *sample.args, **sample.kwargs)
            ...
```

This is a common pattern for such tests: simply generate a list of `SampleInput`s and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you define rules to get very flexible xfails / skips based on op / sample input properties:

```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the
# @ops decorator, but it can be more readable to maintain these somewhere else. These are
# attempted to be matched in order and the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
    XFailRule(
        error_type=ValueError,
        error_msg="2D inputs not supported",
        op_match_fn=lambda device, op: (
            # NB: logic for which ops this rule applies to goes here
            op.full_name == "add"
        ),
        sample_match_fn=lambda device, sample: (
            # NB: logic for which samples this rule applies to goes here
            sample.input.dim() == 2
        ),
        # NB: optional rule identifier can help with debugging matched rules
        name="add_with_2D_inputs_not_supported",
    ),
    # NB: This follows a similar structure as XFailRule but without error_type / error_msg.
    # Obviously this skips a particular SampleInput instead of xfailing :)
    SkipRule(...),
    ...
]


class MyTestCase(TestCase):
    @ops(op_db)
    @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
    # NB: the @ops decorator automatically filters out any rules that don't apply to this op
    def test_foo(self, device, dtype, op):
        for sample, subtest_ctx in op.sample_inputs(
            # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails
            # are defined and use_subtests != True, an informative error will be thrown.
            device, dtype, requires_grad=False, use_subtests=True
        ):
            # NB: this subtest context manager runs each sample input as a "subtest" and
            # handles skips / xfails appropriately
            with subtest_ctx(self):
                # do some SampleInput-based test logic
                output = op.op(sample.input, *sample.args, **sample.kwargs)
                ...
```

More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.

I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`:

```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
    "to",
    lambda sample: (
        sample.kwargs["memory_format"] == torch.channels_last
    ),
),
...
```

Pull Request resolved: pytorch#140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: pytorch#140160, pytorch#138370
TSIA

Pull Request resolved: pytorch#141088
Approved by: https://github.com/atalman