Skip to content

Conversation

@benjaminglass1
Copy link
Collaborator

@benjaminglass1 benjaminglass1 commented Sep 11, 2024

Stack from ghstack (oldest at bottom):

Adds lowering for aten.searchsorted. This entails:

  1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
  2. Adding support for striding to ops.bucketize.
  3. Adding support for sorting tensors to ops.bucketize.
  4. Adding a lowering for aten.searchsorted.Tensor.
  5. Adding a basic decomposition for aten.searchsorted.Scalar that calls into the lowering for tensors.
  6. Updating the meta-function for aten.searchsorted to properly check some of the sizing conditions.

Closes #135873

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

Differential Revision: D63766514

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135701

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a3c4e98 with merge base 4d3c0fc (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 358adf6
Pull Request resolved: #135701
@benjaminglass1 benjaminglass1 self-assigned this Sep 11, 2024
@benjaminglass1
Copy link
Collaborator Author

Benchmark script is attached; outputs are below. Generally, there seem to be no time improvements relative to eager mode for the case without a sorter array, but cases with a sorter array get a ~50% speedup.

[----------------------------- 1-D presorted sequence ----------------------------]
                                                            |  Decomposed  |  Eager
32 threads: -----------------------------------------------------------------------
      buckets shape (32,), values shape (32, 32, 32)        |       37     |      8
      buckets shape (32,), values shape (32, 256, 32)       |       43     |     13
      buckets shape (32,), values shape (256, 32, 32)       |       44     |     13
      buckets shape (32,), values shape (256, 256, 32)      |       91     |     94
      buckets shape (64,), values shape (32, 32, 64)        |       45     |      8
      buckets shape (64,), values shape (32, 256, 64)       |       45     |     29
      buckets shape (64,), values shape (256, 32, 64)       |       44     |     30
      buckets shape (64,), values shape (256, 256, 64)      |      182     |    190
      buckets shape (128,), values shape (32, 32, 128)      |       44     |      8
      buckets shape (128,), values shape (32, 256, 128)     |       49     |     60
      buckets shape (128,), values shape (256, 32, 128)     |       49     |     60
      buckets shape (128,), values shape (256, 256, 128)    |      375     |    392
      buckets shape (256,), values shape (32, 32, 256)      |       44     |     20
      buckets shape (256,), values shape (32, 256, 256)     |      131     |    140
      buckets shape (256,), values shape (256, 32, 256)     |      131     |    150
      buckets shape (256,), values shape (256, 256, 256)    |     1032     |   1000
      buckets shape (512,), values shape (32, 32, 512)      |       52     |     57
      buckets shape (512,), values shape (32, 256, 512)     |      390     |    372
      buckets shape (512,), values shape (256, 32, 512)     |      398     |    378
      buckets shape (512,), values shape (256, 256, 512)    |     3206     |   3269
      buckets shape (1024,), values shape (32, 32, 1024)    |      144     |    160
      buckets shape (1024,), values shape (32, 256, 1024)   |     1128     |   1200
      buckets shape (1024,), values shape (256, 32, 1024)   |     1144     |   1200
      buckets shape (1024,), values shape (256, 256, 1024)  |     9187     |   9000
      buckets shape (2048,), values shape (32, 32, 2048)    |      375     |    372
      buckets shape (2048,), values shape (32, 256, 2048)   |     2992     |   2951
      buckets shape (2048,), values shape (256, 32, 2048)   |     2996     |   2952
      buckets shape (2048,), values shape (256, 256, 2048)  |    24100     |  23780

Times are in microseconds (us).

[----------------------------- 1-D unsorted sequence -----------------------------]
                                                            |  Decomposed  |  Eager
32 threads: -----------------------------------------------------------------------
      buckets shape (32,), values shape (32, 32, 32)        |       38     |     33
      buckets shape (32,), values shape (32, 256, 32)       |       44     |     41
      buckets shape (32,), values shape (256, 32, 32)       |       46     |     36
      buckets shape (32,), values shape (256, 256, 32)      |       92     |    133
      buckets shape (64,), values shape (32, 32, 64)        |       47     |     33
      buckets shape (64,), values shape (32, 256, 64)       |       49     |     67
      buckets shape (64,), values shape (256, 32, 64)       |       45     |     68
      buckets shape (64,), values shape (256, 256, 64)      |      281     |    290
      buckets shape (128,), values shape (32, 32, 128)      |       45     |     41
      buckets shape (128,), values shape (32, 256, 128)     |      108     |    134
      buckets shape (128,), values shape (256, 32, 128)     |      107     |    134
      buckets shape (128,), values shape (256, 256, 128)    |      847     |    790
      buckets shape (256,), values shape (32, 32, 256)      |       45     |     66
      buckets shape (256,), values shape (32, 256, 256)     |      318     |    324
      buckets shape (256,), values shape (256, 32, 256)     |      318     |    324
      buckets shape (256,), values shape (256, 256, 256)    |     1000     |   2285
      buckets shape (512,), values shape (32, 32, 512)      |      115     |    138
      buckets shape (512,), values shape (32, 256, 512)     |      887     |    841
      buckets shape (512,), values shape (256, 32, 512)     |      890     |    844
      buckets shape (512,), values shape (256, 256, 512)    |     3230     |   6400
      buckets shape (1024,), values shape (32, 32, 1024)    |      291     |    305
      buckets shape (1024,), values shape (32, 256, 1024)   |     1000     |   2139
      buckets shape (1024,), values shape (256, 32, 1024)   |     1000     |   2154
      buckets shape (1024,), values shape (256, 256, 1024)  |     9300     |  17000
      buckets shape (2048,), values shape (32, 32, 2048)    |      707     |    680
      buckets shape (2048,), values shape (32, 256, 2048)   |     3021     |   5273
      buckets shape (2048,), values shape (256, 32, 2048)   |     2956     |   5276
      buckets shape (2048,), values shape (256, 256, 2048)  |    23620     |  41800

Times are in microseconds (us).

[--------------------------------- N-D presorted sequence ---------------------------------]
                                                                     |  Decomposed  |  Eager
32 threads: --------------------------------------------------------------------------------
      buckets shape (32, 32, 32), values shape (32, 32, 32)          |       46     |      9
      buckets shape (32, 256, 32), values shape (32, 256, 32)        |       46     |     24
      buckets shape (256, 32, 32), values shape (256, 32, 32)        |       46     |     24
      buckets shape (256, 256, 32), values shape (256, 256, 32)      |      139     |    200
      buckets shape (32, 32, 64), values shape (32, 32, 64)          |       46     |      9
      buckets shape (32, 256, 64), values shape (32, 256, 64)        |       46     |     52
      buckets shape (256, 32, 64), values shape (256, 32, 64)        |       46     |     51
      buckets shape (256, 256, 64), values shape (256, 256, 64)      |      282     |    362
      buckets shape (32, 32, 128), values shape (32, 32, 128)        |       46     |     11
      buckets shape (32, 256, 128), values shape (32, 256, 128)      |       79     |    110
      buckets shape (256, 32, 128), values shape (256, 32, 128)      |       78     |    111
      buckets shape (256, 256, 128), values shape (256, 256, 128)    |      605     |    607
      buckets shape (32, 32, 256), values shape (32, 32, 256)        |       46     |     35
      buckets shape (32, 256, 256), values shape (32, 256, 256)      |      171     |    229
      buckets shape (256, 32, 256), values shape (256, 32, 256)      |      173     |    230
      buckets shape (256, 256, 256), values shape (256, 256, 256)    |     1343     |   1340
      buckets shape (32, 32, 512), values shape (32, 32, 512)        |       60     |     83
      buckets shape (32, 256, 512), values shape (32, 256, 512)      |      445     |    573
      buckets shape (256, 32, 512), values shape (256, 32, 512)      |      445     |    574
      buckets shape (256, 256, 512), values shape (256, 256, 512)    |     3519     |   3533
      buckets shape (32, 32, 1024), values shape (32, 32, 1024)      |      157     |    208
      buckets shape (32, 256, 1024), values shape (32, 256, 1024)    |     1212     |   1200
      buckets shape (256, 32, 1024), values shape (256, 32, 1024)    |     1212     |   1200
      buckets shape (256, 256, 1024), values shape (256, 256, 1024)  |     9751     |   9700
      buckets shape (32, 32, 2048), values shape (32, 32, 2048)      |      400     |    477
      buckets shape (32, 256, 2048), values shape (32, 256, 2048)    |     3168     |   3176
      buckets shape (256, 32, 2048), values shape (256, 32, 2048)    |     3172     |   3183
      buckets shape (256, 256, 2048), values shape (256, 256, 2048)  |    25480     |  25440

Times are in microseconds (us).

[--------------------------------- N-D unsorted sequence ----------------------------------]
                                                                     |  Decomposed  |  Eager
32 threads: --------------------------------------------------------------------------------
      buckets shape (32, 32, 32), values shape (32, 32, 32)          |       48     |     42
      buckets shape (32, 256, 32), values shape (32, 256, 32)        |       49     |     68
      buckets shape (256, 32, 32), values shape (256, 32, 32)        |       48     |     71
      buckets shape (256, 256, 32), values shape (256, 256, 32)      |      200     |    353
      buckets shape (32, 32, 64), values shape (32, 32, 64)          |       49     |     43
      buckets shape (32, 256, 64), values shape (32, 256, 64)        |       60     |    125
      buckets shape (256, 32, 64), values shape (256, 32, 64)        |       61     |    125
      buckets shape (256, 256, 64), values shape (256, 256, 64)      |      428     |    740
      buckets shape (32, 32, 128), values shape (32, 32, 128)        |       49     |     46
      buckets shape (32, 256, 128), values shape (32, 256, 128)      |      140     |    243
      buckets shape (256, 32, 128), values shape (256, 32, 128)      |      140     |    244
      buckets shape (256, 256, 128), values shape (256, 256, 128)    |      610     |   1600
      buckets shape (32, 32, 256), values shape (32, 32, 256)        |       51     |     97
      buckets shape (32, 256, 256), values shape (32, 256, 256)      |      368     |    546
      buckets shape (256, 32, 256), values shape (256, 32, 256)      |      369     |    547
      buckets shape (256, 256, 256), values shape (256, 256, 256)    |     1300     |   3922
      buckets shape (32, 32, 512), values shape (32, 32, 512)        |      130     |    200
      buckets shape (32, 256, 512), values shape (32, 256, 512)      |      573     |   1300
      buckets shape (256, 32, 512), values shape (256, 32, 512)      |      572     |   1300
      buckets shape (256, 256, 512), values shape (256, 256, 512)    |     3528     |   9968
      buckets shape (32, 32, 1024), values shape (32, 32, 1024)      |      337     |    434
      buckets shape (32, 256, 1024), values shape (32, 256, 1024)    |     1000     |   3129
      buckets shape (256, 32, 1024), values shape (256, 32, 1024)    |     1000     |   3133
      buckets shape (256, 256, 1024), values shape (256, 256, 1024)  |     9800     |  24590
      buckets shape (32, 32, 2048), values shape (32, 32, 2048)      |      859     |    960
      buckets shape (32, 256, 2048), values shape (32, 256, 2048)    |     3186     |   7390
      buckets shape (256, 32, 2048), values shape (256, 32, 2048)    |     3163     |   7400
      buckets shape (256, 256, 2048), values shape (256, 256, 2048)  |    25270     |  58800

Times are in microseconds (us).

Benchmark script:

import torch
from torch.testing import make_tensor
from torch.utils.benchmark import Timer, Compare
import torch._inductor.config
from itertools import product
from functools import partial
import torch._dynamo.config

torch._dynamo.config.cache_size_limit = 1000
torch._inductor.config.force_disable_caches = True
torch._inductor.config.triton.cudagraphs = False

benchmark_name = "searchsorted"
XY_sizes = [32, 256]
Z_sizes = [32, 64, 128, 256, 512, 1024, 2048]


def gen_inputs():
    make_arg = partial(make_tensor, dtype=torch.float32, device="cuda:1")

    for Z, X, Y in product(Z_sizes, XY_sizes, XY_sizes):
        unsorted_seq = make_arg(Z)
        sorted_seq, sorting_indices = torch.sort(unsorted_seq)
        values = make_arg((X, Y, Z))

        # 1-D sorted case
        yield sorted_seq, values, None
        # 1-D unsorted case
        yield unsorted_seq, values, sorting_indices

    for Z, X, Y in product(Z_sizes, XY_sizes, XY_sizes):
        unsorted_seq = make_arg((X, Y, Z))
        sorted_seq, sorting_indices = torch.sort(unsorted_seq)
        values = make_arg((X, Y, Z))

        # N-D sorted case
        yield sorted_seq, values, None
        # N-D unsorted case
        yield unsorted_seq, values, sorting_indices


def benchmark(label, f, args):
    dim_label = "1-D" if len(args[0].shape) == 1 else "N-D"
    sorting_label = "presorted" if args[2] is None else "unsorted"

    # Warm up the timed function kernel.
    f(*args)
    f(*args)

    return Timer(
        "f(*args)",
        globals={"f": f, "args": args},
        label=f"{dim_label} {sorting_label} sequence",
        description=label,
        sub_label=f"buckets shape {tuple(args[0].shape)}, values shape {tuple(args[1].shape)}",
        num_threads=torch.get_num_threads(),
    ).blocked_autorange()


def compare(sorted_sequence, values, sorter=None):
    def f(s, v, sort):
        return torch.ops.aten.searchsorted(s, v, sorter=sort)

    f_compile = torch.compile(f)

    args = (sorted_sequence, values, sorter)
    yield benchmark("Decomposed", f_compile, args)
    yield benchmark("Eager", f, args)


c = Compare([r for args in gen_inputs() for r in compare(*args)])
c.trim_significant_figures()
c.print()

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 6f5a83e
Pull Request resolved: #135701
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: eaa4c44
Pull Request resolved: #135701
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 144eddf
Pull Request resolved: #135701
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 181e29e
Pull Request resolved: #135701
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 8eae6f4
Pull Request resolved: #135701
@benjaminglass1 benjaminglass1 marked this pull request as draft September 11, 2024 20:30
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 11, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 055bbff
Pull Request resolved: #135701
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 12, 2024
In order to support searchsorted, the underlying bucketize op had two
make two changes:

1. Supporting multi-dimensional buckets (done by passing an additional
   index parameter specifying the length of a single set of buckets).
2. Supporting sorter tensors (mostly implemented within the
   triton_helpers bucketize helper).

To implement this lowering, an extremely basic inductor prim and decomp
are defined, which enable us to make the sorter array a positional
argument rather than a keyword argument (lowering do not support)

ghstack-source-id: 65a6612
Pull Request resolved: #135701
@benjaminglass1 benjaminglass1 marked this pull request as ready for review September 12, 2024 16:13
Copy link
Collaborator

@amjames amjames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor things that I caught, otherwise looks good!

[ghstack-poisoned]
[ghstack-poisoned]
@benjaminglass1 benjaminglass1 added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 27, 2024
[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Sep 28, 2024
This entails:

1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
2. Adding support for striding to ops.bucketize.
3. Adding support for sorting tensors to ops.bucketize.
4. Adding a lowering for aten.searchsorted.Tensor.
5. Adding a basic decomposition for aten.searchsorted.Scalar that calls
   into the lowering for tensors.
6. Updating the meta-function for aten.searchsorted to properly check
   some of the sizing conditions.

ghstack-source-id: 9256786
Pull Request resolved: #135701
Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already approved by me - differing to @davidberard98 on resolving comments he requested changes for

Copy link
Contributor

@davidberard98 davidberard98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delay - left a few comments, but generally LGTM at a high level.

However, since this does change some APIs that are used internally at Meta, I think I'll need to land this manually. Once the PR is ready (any PRs that it depends on are landed, plus any other comments resolved), could you let me know and I can make sure it gets landed with all the necessary Meta-only changes?

Adds lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.

Closes #135873

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Oct 1, 2024
This entails:

1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
2. Adding support for striding to ops.bucketize.
3. Adding support for sorting tensors to ops.bucketize.
4. Adding a lowering for aten.searchsorted.Tensor.
5. Adding a basic decomposition for aten.searchsorted.Scalar that calls
   into the lowering for tensors.
6. Updating the meta-function for aten.searchsorted to properly check
   some of the sizing conditions.

ghstack-source-id: 7698870
Pull Request resolved: #135701
Adds lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.

Closes #135873

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Oct 1, 2024
This entails:

1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
2. Adding support for striding to ops.bucketize.
3. Adding support for sorting tensors to ops.bucketize.
4. Adding a lowering for aten.searchsorted.Tensor.
5. Adding a basic decomposition for aten.searchsorted.Scalar that calls
   into the lowering for tensors.
6. Updating the meta-function for aten.searchsorted to properly check
   some of the sizing conditions.

ghstack-source-id: 61db6d5
Pull Request resolved: #135701
Adds lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.

Closes #135873

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Oct 2, 2024
This entails:

1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
2. Adding support for striding to ops.bucketize.
3. Adding support for sorting tensors to ops.bucketize.
4. Adding a lowering for aten.searchsorted.Tensor.
5. Adding a basic decomposition for aten.searchsorted.Scalar that calls
   into the lowering for tensors.
6. Updating the meta-function for aten.searchsorted to properly check
   some of the sizing conditions.

ghstack-source-id: 748545b
Pull Request resolved: #135701
@davidberard98
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment
@davidberard98
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

[ghstack-poisoned]
benjaminglass1 added a commit that referenced this pull request Oct 3, 2024
This entails:

1. Adding support for multi-dimensional bucket tensors to ops.bucketize.
2. Adding support for striding to ops.bucketize.
3. Adding support for sorting tensors to ops.bucketize.
4. Adding a lowering for aten.searchsorted.Tensor.
5. Adding a basic decomposition for aten.searchsorted.Scalar that calls
   into the lowering for tensors.
6. Updating the meta-function for aten.searchsorted to properly check
   some of the sizing conditions.

ghstack-source-id: 0023733
Pull Request resolved: #135701
[ghstack-poisoned]
@davidberard98
Copy link
Contributor

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Oct 18, 2024
Add an additional check that scalars wrapped to 0-D tensors by dynamo are actually 0-D.  This fixes a bug where a 1-D tensor was mistakenly converted to a scalar value rather than passed as a pointer.

Pull Request resolved: #137303
Approved by: https://github.com/eellison
ghstack dependencies: #135701
@github-actions github-actions bot deleted the gh/benjaminglass1/8/head branch November 6, 2024 02:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants