
Conversation

@eellison
Contributor

@eellison eellison commented Aug 2, 2024

Stack from ghstack (oldest at bottom):

When we are autotuning matmuls, both the aten.mm and the triton template choices take in an externally allocated output tensor, which can be a view into a pre-planned aten.cat. As long as the output shape and stride of the matmul match the slice of the cat we're planning, we can realize the mm directly into the cat.
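As a rough illustration of the idea at the eager level (a minimal sketch with made-up shapes, not the actual Inductor code path): instead of computing the mm into a freshly allocated buffer and then copying it into the concatenated output, we write it straight into the matching slice of a pre-allocated cat buffer.

```
import torch

# Minimal sketch (illustrative shapes only): realize the mm directly into
# the slice of a pre-planned cat buffer instead of allocating + copying.
a = torch.randn(128, 64)
b = torch.randn(64, 32)
other = torch.randn(128, 32)

# Buffer for torch.cat([a @ b, other], dim=1), planned up front.
cat_buf = torch.empty(128, 64)

# The mm output's shape and stride match the first slice of the cat,
# so we can compute into that view directly.
torch.mm(a, b, out=cat_buf[:, :32])
cat_buf[:, 32:].copy_(other)

assert torch.allclose(cat_buf, torch.cat([a @ b, other], dim=1))
```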

Discussion for reviewers:

It feels a little bit odd that in the existing code we set the output of aten.mm as [FlexibleLayout](https://github.com/pytorch/pytorch/blob/bcac71517c461765b4fa9efccc6f1a5a475c3544/torch/_inductor/kernel/mm.py#L156). While this is correct, it might lead to passing non-performant output strides to cublas. I guess this is better than a copy? Not sure. We could also introduce a Layout that denotes a fixed shape and stride whose allocation we control:

```
class AllocatedFixedLayout(FixedLayout)
```

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Aug 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132554

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 914f5ab with merge base d1b87e2:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Aug 2, 2024
eellison added a commit that referenced this pull request Aug 2, 2024
eellison added a commit that referenced this pull request Aug 2, 2024
Collaborator

@Chillee Chillee left a comment

I haven't looked at this carefully yet, but will this work with all triton templates? Thinking about FlexAttention here.

@eellison
Contributor Author

eellison commented Aug 2, 2024

Ugh, it could, but it doesn't right now because I only implemented this for MultiTemplateBuffer, and flex_attention has input_gen_fns, which is not yet implemented for MultiTemplate. It could be done without much difficulty, though.

But I wanted to resolve the current, non-max-autotune handling of mms first; then we can handle it. As above, I think returning FlexibleLayout for aten.mm is misleading/buggy, and we should not rely on that for cat planning.

Options:

  • make the aten.mm return FixedLayout, and check for external kernel alloc in concat planning
  • introduce AllocatedFixedLayout (a rough sketch follows below)
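For the second option, here is a rough sketch of what such a layout might look like (hypothetical: AllocatedFixedLayout does not exist in Inductor today, and the class here is only a marker expressing the intent, not a real API):

```
from torch._inductor.ir import FixedLayout

# Hypothetical sketch: same fixed size/stride semantics as FixedLayout,
# but tagged as "Inductor controls this allocation" (rather than an
# externally provided buffer), so concat planning can still treat the
# node as a valid target to realize into.
class AllocatedFixedLayout(FixedLayout):
    """FixedLayout whose storage Inductor allocates itself, keeping it
    eligible for cat planning."""
```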

@Chillee
Collaborator

Chillee commented Aug 2, 2024

It feels a little bit odd that in the existing code we set the output of aten.mm as FlexibleLayout even though its shape and stride are fixed.

I don't actually understand this? Isn't aten.mm codegened with an out parameter? So its stride isn't actually fixed?

@eellison
Contributor Author

eellison commented Aug 2, 2024

Hmm, maybe; I don't know what would actually happen if you pass cublas a weird output stride.

I think you are correct that it would work, but we also don't want to pass in a transposed output to a cublas kernel and get a bunch of discontiguous writes.

Do we actually want the output strides of mms to be flexible?

@eellison
Contributor Author

eellison commented Aug 2, 2024

Hmm, at least this was about equal:

```
import torch
import triton
from torch._inductor.select_algorithm import extern_kernels

torch.set_default_device('cuda')

inps = [torch.rand([4096, 4096], dtype=torch.float16) for _ in range(2)]
out1 = inps[0].clone()      # contiguous (row-major) output
out2 = inps[0].clone().T    # transposed output: same shape, swapped strides

# Benchmark cublas mm writing into a contiguous vs. a transposed out= buffer.
print(triton.testing.do_bench(lambda: extern_kernels.mm(inps[0], inps[1], out=out1)))
print(triton.testing.do_bench(lambda: extern_kernels.mm(inps[0], inps[1], out=out2)))
```

Similarly for FlexAttention: if we just change the layout to be FlexibleLayout, this would work today, but are you okay with the output strides potentially being non-contiguous?

@Chillee
Collaborator

Chillee commented Aug 3, 2024

are you okay with the output strides potentially being non contiguous

Yeah, I think so. Well, I'd definitely want them to be "contiguous enough".
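One way to read "contiguous enough" (an illustrative interpretation, not what Inductor actually checks) is that the innermost dimension stays stride-1, so writes remain coalesced even when the outer strides are padded, e.g. for a slice of a cat buffer:

```
import torch

def contiguous_enough(t: torch.Tensor) -> bool:
    # Illustrative check only: the innermost dim is stride-1, so writes
    # stay coalesced even if outer strides are padded (e.g. a cat slice).
    return t.stride(-1) == 1

x = torch.empty(128, 64)
print(contiguous_enough(x))          # True: fully contiguous
print(contiguous_enough(x[:, :32]))  # True: padded rows, inner stride still 1
print(contiguous_enough(x.T))        # False: transposed, inner stride is 64
```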

eellison added a commit that referenced this pull request Oct 3, 2024
@eellison eellison added the topic: not user facing label Oct 4, 2024
@eellison
Contributor Author

eellison commented Oct 4, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased gh/eellison/686/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/132554)

pytorchmergebot pushed a commit that referenced this pull request Oct 4, 2024
eellison added a commit that referenced this pull request Oct 7, 2024
@eellison
Contributor Author

eellison commented Oct 7, 2024

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Oct 7, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following check: inductor / linux-jammy-cpu-py3.9-gcc11-inductor / test (cpu_inductor_freezing_avx2_timm, 2, 2, lf.linux.10xlarge.avx2)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@huydhn
Contributor

huydhn commented Oct 8, 2024

@pytorchbot revert -m 'Sorry for reverting your change but I think it is failing on ROCm' -c nosignal

inductor/test_max_autotune.py::TestMaxAutotune::test_conv_cat (GH job link, HUD commit link)

@huydhn huydhn added the ciflow/rocm label Oct 8, 2024
@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Oct 8, 2024
…32554)"

This reverts commit d558ec0.

Reverted #132554 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think it is failing on ROCm ([comment](#132554 (comment)))
@pytorchmergebot
Collaborator

@eellison your PR has been successfully reverted.

eellison added a commit that referenced this pull request Oct 8, 2024
@eellison
Contributor Author

eellison commented Oct 8, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


Labels

ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: inductor, Reverted, topic: not user facing

6 participants