-
Notifications
You must be signed in to change notification settings - Fork 26.3k
NJT OpInfo tests v2 #138370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NJT OpInfo tests v2 #138370
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138370
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit dc5e500 with merge base e6c5a77 ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
soulitzer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! Awesome work!
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. The key goal of this PR is to maintain clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. How does it work? Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SAMPLE_RULES = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db, sample_rules=FOO_SAMPLE_RULES) # NB: the ops decorator automatically filters out any sample_rules that don't apply to this op def test_foo(self, device, dtype, op, sample_rules): for sample, subtest_ctx in op.sample_inputs( # NB: passing sample_rules here enables the opt-in functionality to get subtest xfails / skips device, dtype, requires_grad=False, sample_rules=sample_rules ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` "Rules" above can only be xfails or skips. More examples can be seen in `test/test_nestedtensor.py`, where this stuff is used in practice. There is also some logging of matched rules for debugging purposes accessible by setting the loglevel to `DEBUG`. [ghstack-poisoned]
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. The key goal of this PR is to maintain clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. How does it work? Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SAMPLE_RULES = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db, sample_rules=FOO_SAMPLE_RULES) # NB: the ops decorator automatically filters out any sample_rules that don't apply to this op def test_foo(self, device, dtype, op, sample_rules): for sample, subtest_ctx in op.sample_inputs( # NB: passing sample_rules here enables the opt-in functionality to get subtest xfails / skips device, dtype, requires_grad=False, sample_rules=sample_rules ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` "Rules" above can only be xfails or skips. More examples can be seen in `test/test_nestedtensor.py`, where this stuff is used in practice. There is also some logging of matched rules for debugging purposes accessible by setting the loglevel to `DEBUG`. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. (Note: this term might be too general, making this more confusing than it needs to be; please suggest alternatives). ```python dataclass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" dataclass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" dataclass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_rules` that specify xfails / skips. * The list of `sample_rules` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SAMPLE_RULES = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db, sample_rules=FOO_SAMPLE_RULES) # NB: the ops decorator automatically filters out any sample_rules that don't apply to this op def test_foo(self, device, dtype, op, sample_rules): for sample, subtest_ctx in op.sample_inputs( # NB: passing sample_rules here enables the opt-in functionality to get subtest xfails / skips device, dtype, requires_grad=False, sample_rules=sample_rules ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python dataclass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" dataclass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" dataclass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_rules` that specify xfails / skips. * The list of `sample_rules` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db, sample_skips_and_xfails=FOO_SKIPS_AND_XFAILS) # NB: the ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op, sample_skips_and_xfails): for sample, subtest_ctx in op.sample_inputs( # NB: passing sample_skips_and_xfails here enables the opt-in functionality to get subtest xfails / skips device, dtype, requires_grad=False, sample_skips_and_xfails=sample_skips_and_xfails ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python dataclass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" dataclass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" dataclass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db, sample_skips_and_xfails=FOO_SKIPS_AND_XFAILS) # NB: the ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op, sample_skips_and_xfails): for sample, subtest_ctx in op.sample_inputs( # NB: passing sample_skips_and_xfails here enables the opt-in functionality to get subtest xfails / skips device, dtype, requires_grad=False, sample_skips_and_xfails=sample_skips_and_xfails ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python dataclass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" dataclass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" dataclass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db) sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python dataclass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" dataclass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" dataclass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): ops(op_db) sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. [ghstack-poisoned]
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: #140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: #140160, #138370
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See pytorch#89354 and pytorch#82669 for some requests asking for this type of functionality. This was originally landed for NJT in pytorch#138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: pytorch#140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: pytorch#140160, pytorch#138370
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@DataClass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: pytorch#138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: pytorch#140160
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See pytorch#89354 and pytorch#82669 for some requests asking for this type of functionality. This was originally landed for NJT in pytorch#138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: pytorch#140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: pytorch#140160, pytorch#138370
ghstack-source-id: ab33ab3 Pull Request resolved: pytorch/pytorch#138370
Stack from ghstack (oldest at bottom):
This PR updates OpInfo-based tests for NJTs:
_sample_njts()helper thatsample_input_funcs utilize now produces non-contig NJTs as wellSampleInput-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.SampleInputs and utilizing this system over using test parametrization for two reasons:SampleInputs is faster than generating entire new testssample_input_funcs not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility forSampleInputgeneration firmly within eachOpInfo'ssample_input_funcmeans weirdness like this isn't possibleSampleInputnaming for a bunch ofsample_input_funcs. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new_describe_njt()helper to get a string repr that uniquely defines the type of NJT being passed to the opXFailRules to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixedExample: