-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add lowering for aten.searchsorted #135701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add lowering for aten.searchsorted #135701
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135701
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit a3c4e98 with merge base 4d3c0fc ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 358adf6 Pull Request resolved: #135701
|
Benchmark script is attached; outputs are below. Generally, there seem to be no time improvements relative to eager mode for the case without a sorter array, but cases with a sorter array get a ~50% speedup. Benchmark script: import torch
from torch.testing import make_tensor
from torch.utils.benchmark import Timer, Compare
import torch._inductor.config
from itertools import product
from functools import partial
import torch._dynamo.config
torch._dynamo.config.cache_size_limit = 1000
torch._inductor.config.force_disable_caches = True
torch._inductor.config.triton.cudagraphs = False
benchmark_name = "searchsorted"
XY_sizes = [32, 256]
Z_sizes = [32, 64, 128, 256, 512, 1024, 2048]
def gen_inputs():
make_arg = partial(make_tensor, dtype=torch.float32, device="cuda:1")
for Z, X, Y in product(Z_sizes, XY_sizes, XY_sizes):
unsorted_seq = make_arg(Z)
sorted_seq, sorting_indices = torch.sort(unsorted_seq)
values = make_arg((X, Y, Z))
# 1-D sorted case
yield sorted_seq, values, None
# 1-D unsorted case
yield unsorted_seq, values, sorting_indices
for Z, X, Y in product(Z_sizes, XY_sizes, XY_sizes):
unsorted_seq = make_arg((X, Y, Z))
sorted_seq, sorting_indices = torch.sort(unsorted_seq)
values = make_arg((X, Y, Z))
# N-D sorted case
yield sorted_seq, values, None
# N-D unsorted case
yield unsorted_seq, values, sorting_indices
def benchmark(label, f, args):
dim_label = "1-D" if len(args[0].shape) == 1 else "N-D"
sorting_label = "presorted" if args[2] is None else "unsorted"
# Warm up the timed function kernel.
f(*args)
f(*args)
return Timer(
"f(*args)",
globals={"f": f, "args": args},
label=f"{dim_label} {sorting_label} sequence",
description=label,
sub_label=f"buckets shape {tuple(args[0].shape)}, values shape {tuple(args[1].shape)}",
num_threads=torch.get_num_threads(),
).blocked_autorange()
def compare(sorted_sequence, values, sorter=None):
def f(s, v, sort):
return torch.ops.aten.searchsorted(s, v, sorter=sort)
f_compile = torch.compile(f)
args = (sorted_sequence, values, sorter)
yield benchmark("Decomposed", f_compile, args)
yield benchmark("Eager", f, args)
c = Compare([r for args in gen_inputs() for r in compare(*args)])
c.trim_significant_figures()
c.print() |
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 6f5a83e Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: eaa4c44 Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 144eddf Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 181e29e Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 8eae6f4 Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 055bbff Pull Request resolved: #135701
In order to support searchsorted, the underlying bucketize op had two make two changes: 1. Supporting multi-dimensional buckets (done by passing an additional index parameter specifying the length of a single set of buckets). 2. Supporting sorter tensors (mostly implemented within the triton_helpers bucketize helper). To implement this lowering, an extremely basic inductor prim and decomp are defined, which enable us to make the sorter array a positional argument rather than a keyword argument (lowering do not support) ghstack-source-id: 65a6612 Pull Request resolved: #135701
amjames
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor things that I caught, otherwise looks good!
This entails: 1. Adding support for multi-dimensional bucket tensors to ops.bucketize. 2. Adding support for striding to ops.bucketize. 3. Adding support for sorting tensors to ops.bucketize. 4. Adding a lowering for aten.searchsorted.Tensor. 5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors. 6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions. ghstack-source-id: 9256786 Pull Request resolved: #135701
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
already approved by me - differing to @davidberard98 on resolving comments he requested changes for
davidberard98
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry for the delay - left a few comments, but generally LGTM at a high level.
However, since this does change some APIs that are used internally at Meta, I think I'll need to land this manually. Once the PR is ready (any PRs that it depends on are landed, plus any other comments resolved), could you let me know and I can make sure it gets landed with all the necessary Meta-only changes?
Adds lowering for `aten.searchsorted`. This entails: 1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`. 2. Adding support for striding to `ops.bucketize`. 3. Adding support for sorting tensors to `ops.bucketize`. 4. Adding a lowering for `aten.searchsorted.Tensor`. 5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors. 6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions. Closes #135873 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This entails: 1. Adding support for multi-dimensional bucket tensors to ops.bucketize. 2. Adding support for striding to ops.bucketize. 3. Adding support for sorting tensors to ops.bucketize. 4. Adding a lowering for aten.searchsorted.Tensor. 5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors. 6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions. ghstack-source-id: 7698870 Pull Request resolved: #135701
Adds lowering for `aten.searchsorted`. This entails: 1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`. 2. Adding support for striding to `ops.bucketize`. 3. Adding support for sorting tensors to `ops.bucketize`. 4. Adding a lowering for `aten.searchsorted.Tensor`. 5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors. 6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions. Closes #135873 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This entails: 1. Adding support for multi-dimensional bucket tensors to ops.bucketize. 2. Adding support for striding to ops.bucketize. 3. Adding support for sorting tensors to ops.bucketize. 4. Adding a lowering for aten.searchsorted.Tensor. 5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors. 6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions. ghstack-source-id: 61db6d5 Pull Request resolved: #135701
Adds lowering for `aten.searchsorted`. This entails: 1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`. 2. Adding support for striding to `ops.bucketize`. 3. Adding support for sorting tensors to `ops.bucketize`. 4. Adding a lowering for `aten.searchsorted.Tensor`. 5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors. 6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions. Closes #135873 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This entails: 1. Adding support for multi-dimensional bucket tensors to ops.bucketize. 2. Adding support for striding to ops.bucketize. 3. Adding support for sorting tensors to ops.bucketize. 4. Adding a lowering for aten.searchsorted.Tensor. 5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors. 6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions. ghstack-source-id: 748545b Pull Request resolved: #135701
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
1 similar comment
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
This entails: 1. Adding support for multi-dimensional bucket tensors to ops.bucketize. 2. Adding support for striding to ops.bucketize. 3. Adding support for sorting tensors to ops.bucketize. 4. Adding a lowering for aten.searchsorted.Tensor. 5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors. 6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions. ghstack-source-id: 0023733 Pull Request resolved: #135701
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Add an additional check that scalars wrapped to 0-D tensors by dynamo are actually 0-D. This fixes a bug where a 1-D tensor was mistakenly converted to a scalar value rather than passed as a pointer. Pull Request resolved: #137303 Approved by: https://github.com/eellison ghstack dependencies: #135701
Stack from ghstack (oldest at bottom):
Adds lowering for
aten.searchsorted. This entails:ops.bucketize.ops.bucketize.ops.bucketize.aten.searchsorted.Tensor.aten.searchsorted.Scalarthat calls into the lowering for tensors.aten.searchsortedto properly check some of the sizing conditions.Closes #135873
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang
Differential Revision: D63766514