Skip to content

Conversation

@exclamaforte
Copy link
Contributor

@exclamaforte exclamaforte commented Aug 22, 2025

Summary

Adds a subgraph decomposition for addmm and mm that performs well on large K compared to M and N, and functions well as an alternative to split-k on AMD (transposed only), which does not support AMD currently.

Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:

  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))

is a lot slower than:

  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))

This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:

Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

Results

The largest shape gain, 256,197951,512, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:

addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...

Compared to the non-transposed autotune:

addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916

It seems to perform really well for high values of K vs N and M.
Testing this hypothesis with some custom shapes:

Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

Conclusion

Contiguous Subgraph Decompositionseems worthwhile for mm and addmm, but not bmm, and has a very large improvment on low M, low N, and high K shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

Test Plan:

New unit tests.

Differential Revision: D80771648

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161241

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit e05eae3 with merge base a99d8d3 (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D80771648

@facebook-github-bot
Copy link
Contributor

@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 22, 2025
@exclamaforte exclamaforte force-pushed the export-D80771648 branch 3 times, most recently from 4a42dca to 1bb3352 Compare August 26, 2025 17:25
@exclamaforte exclamaforte requested a review from eellison August 26, 2025 22:18
@facebook-github-bot
Copy link
Contributor

@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648.

@exclamaforte exclamaforte changed the title Contiguous transpose subgraph decomposition Contiguous subgraph decomposition Aug 26, 2025
@exclamaforte exclamaforte requested review from jansel and removed request for eellison and jansel August 27, 2025 04:36
Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good ! cc @PaulZhang12

@coconutruben
Copy link
Contributor

@exclamaforte can you do this in the same way we did #161026 ? this will allow us to integrate with everything else nicely

  1. make a contigous_mm template
  2. add a template heuristic for it, can just yield a single dict with no kwargs

@exclamaforte
Copy link
Contributor Author

@pytorchbot merge

@facebook-github-bot
Copy link
Contributor

@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team Raised by workflow job

@pytorchmergebot
Copy link
Collaborator

@exclamaforte your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Sep 4, 2025
@jeffdaily jeffdaily added the ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 label Sep 4, 2025
@facebook-github-bot
Copy link
Contributor

@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648.

@exclamaforte
Copy link
Contributor Author

exclamaforte commented Sep 4, 2025

Verified that I fixed this test now on an AMD devserver

@exclamaforte
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!

Details for Dev Infra team Raised by workflow job

@facebook-github-bot
Copy link
Contributor

@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648.

@exclamaforte
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
## Summary

Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.

## Background

On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
  ))
```
is a lot slower than:
```
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
  ))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.

## Data

I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
  addmm_16448x512x2048: +0.14%
  addmm_128x2048x2048: +0.01%
  addmm_128x768x1000: +0.75%
  addmm_12672x3072x768: +1.08%
  addmm_512x768x32000: +0.62%
  addmm_12608x384x384: +0.00%
  addmm_4160x1024x4096: +0.90%
  addmm_16x768x2: +0.56%
  addmm_12608x3072x768: +0.09%
  addmm_64x4096x1000: +2.77%
  addmm_256x1024x512: +1.99%
  addmm_30x256x256: +1.12%
  addmm_100480x128x384: +0.91%
  addmm_6400x2048x512: +0.25%
  addmm_61568x1024x256: +0.08%
  addmm_1x768x768: +0.93%
  addmm_12544x384x384: +0.19%
  addmm_128x512x1000: +0.77%
  addmm_2048x128x128: +1.32%
  addmm_128x3072x1000: +0.24%
  addmm_7936x512x2048: +0.07%
  addmm_8192x512x2048: +0.33%
  addmm_64x1024x1000: +1.43%
  addmm_128x2304x1000: +0.01%
  addmm_32768x256x2: +0.75%
  addmm_64x384x1152: +0.79%
  addmm_64x640x1000: +0.01%
  addmm_100480x128x128: +0.87%
  addmm_1152x3072x768: +1.13%
  addmm_8192x256x2048: +1.40%
  addmm_4096x128x768: +0.01%
  addmm_128x2560x1000: +0.01%
  addmm_12544x2048x512: +0.43%
  addmm_200704x24x96: +0.14%
  addmm_8448x512x2048: +0.96%
  addmm_50176x256x1024: +0.62%
  addmm_4160x4096x1024: +0.22%
  addmm_4096x768x768: +0.32%
  addmm_220x2048x512: +0.56%
  addmm_8x2048x1000: +1.12%
  addmm_256x197951x512: +26.99%
  addmm_401536x64x192: +0.60%
  addmm_2040x2048x512: +0.47%
  addmm_512x1024x256: +1.32%
  addmm_128x4096x1000: +1.67%
  addmm_12672x768x768: +0.34%
  addmm_128x368x1000: +0.77%
  addmm_96x1280x1000: +0.01%
  addmm_12544x512x2048: +0.41%
  addmm_6272x320x1280: +0.76%
  addmm_12544x3072x768: +0.09%
  addmm_64x384x1000: +0.39%
mm improvements when best:
  mm_200704x128x512: +1.29%
  mm_663552x16x16: +0.80%
  mm_4096x768x768: +0.51%
  mm_131072x64x31: +0.24%
  mm_12544x1152x384: +0.11%
  mm_128x2048x2: +0.46%
  mm_262144x16x23: +0.62%
  mm_50176x576x192: +0.37%
  mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%

Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)

Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%

```

## Results

The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```

It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
  addmm_128x16384x128: +0.18%
  addmm_128x262144x256: +38.24%
  addmm_128x200000x512: +14.76%
  addmm_256x800000x128: +0.06%
  addmm_131072x128x256: +0.27%
  addmm_128x256x131072: +0.25%
  addmm_2048x200000x64: +12.45%
mm improvements when best:
  mm_128x16384x128: +0.18%
  mm_128x262144x256: +38.05%
  mm_128x200000x512: +9.47%
  mm_256x800000x128: +0.99%
  mm_512x6400000x256: +3.17%
  mm_524288x64x64: +0.29%
  mm_2048x200000x64: +11.19%
  mm_8192x1000000x256: +34.14%
  mm_128x4096x100000: +0.40%
  mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================

Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%

Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%

```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.

Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866

## Test Plan:
New unit tests.

Differential Revision: D80771648

Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
@github-actions github-actions bot deleted the export-D80771648 branch October 5, 2025 02:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged module: inductor Reverted topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants