-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Misc. non-contig NJT fixes #140160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Misc. non-contig NJT fixes #140160
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140160
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 892d613 with merge base e6c5a77 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
cpuhrsch
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stamp
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR contains several fixes related to non-contiguous NJTs: 1. Propagates `lengths` through op calls appropriately (see desc of #138098) * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()` 2. Allows non-contig NJTs in unsqueeze / transpose / select 3. Expands padded dense -> NJT conversion to support non-contig NJTs 4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature [ghstack-poisoned]
Merge failedReason: New commits were pushed while merging. Please rerun the merge command. Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@DataClass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: #138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality. This was originally landed for NJT in #138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: #140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: #140160, #138370
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See pytorch#89354 and pytorch#82669 for some requests asking for this type of functionality. This was originally landed for NJT in pytorch#138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: pytorch#140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: pytorch#140160, pytorch#138370
This PR contains several fixes related to non-contiguous NJTs: 1. Propagates `lengths` through op calls appropriately (see desc of pytorch#138098) * SDPA now calls `nested_view_from_values_offsets_lengths()` instead of `nested_view_from_values_offsets()` 2. Allows non-contig NJTs in unsqueeze / transpose / select 3. Expands padded dense -> NJT conversion to support non-contig NJTs 4. (unrelated sorry) Updates `split` / `split_with_sizes` to allow for optional `dim`, matching the ATen signature Pull Request resolved: pytorch#140160 Approved by: https://github.com/cpuhrsch
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@DataClass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: pytorch#138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: pytorch#140160
### Background This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See pytorch#89354 and pytorch#82669 for some requests asking for this type of functionality. This was originally landed for NJT in pytorch#138370 and is generalized and slightly tweaked here. ### Design #### Principles * Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed). * Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops. * This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example. * Opt-in with minimal test logic changes + no substantial impact on other tests. #### Details The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`. ```python @DataClass class SampleRule(ABC): # function to indicate whether the rule applies to this op; return True if so # NB: str arg of callable is device_type op_match_fn: Callable[[str, OpInfo], bool] = None # function to indicate whether the rule applies to this sample; return True if so sample_match_fn: Callable[[torch.device, SampleInput], bool] = None # optional name for identifying the rule name: str = "" @DataClass class XFailRule(SampleRule): # expected error type error_type: TypeVar = Exception # expected error message error_msg: str = ".*" @DataClass class SkipRule(SampleRule): ... ``` * See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`. * The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter. * The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`. * The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test. * Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic. ### Example Usage Consider the following OpInfo test: ```python class MyTestCase(TestCase): @ops(op_db) def test_foo(self, device, dtype, op): for sample in op.sample_inputs(device, dtype, requires_grad=False): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic. This PR lets you do this to get very flexible xfail / skips based on op / sample input properties: ```python # NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but # it can be more readable to maintain these somewhere else. These are attempted to be matched in order and # the first one that matches applies, so order can matter. FOO_SKIPS_AND_XFAILS = [ XFailRule( error_type=ValueError, error_mg="2D inputs not supported", op_match_fn=lambda device, op: ( # NB: logic for which ops this rule applies to goes here op.full_name == "add" ), sample_match_fn=lambda device, sample: ( # NB: logic which samples this rule applies to goes here sample.input.dim() == 2 ), # NB: optional rule identifier can help with debugging matched rules name="add_with_2D_inputs_not_supported", ), # NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously # this skips a particular SampleInput instead of xfailing :) SkipRule(...), ... ] class MyTestCase(TestCase): @ops(op_db) @sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS) # NB: the @ops decorator automatically filters out any rules that don't apply to this op def test_foo(self, device, dtype, op): for sample, subtest_ctx in op.sample_inputs( # NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True, # an informative error will be thrown. device, dtype, requires_grad=False, use_subtests=True ): # NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately with subtest_ctx(self): # do some SampleInput-based test logic output = op.op(sample.input, *sample.args, **sample.kwargs) ... ``` More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice. I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`: ```python ... # pre-existing xfail xfail("item"), # new granular xfail using syntactic sugar over the general system xfailIf( "to", lambda sample: ( sample.kwargs["memory_format"] == torch.channels_last ), ), ... ``` Pull Request resolved: pytorch#140443 Approved by: https://github.com/janeyx99, https://github.com/zou3519 ghstack dependencies: pytorch#140160, pytorch#138370
ghstack-source-id: 05da0de Pull Request resolved: pytorch/pytorch#140160
Stack from ghstack (oldest at bottom):
This PR contains several fixes related to non-contiguous NJTs:
lengthsthrough op calls appropriately (see desc of Propagate NJT lengths through op calls #138098)nested_view_from_values_offsets_lengths()instead ofnested_view_from_values_offsets()split/split_with_sizesto allow for optionaldim, matching the ATen signature