
Conversation

@aakhundov
Contributor

@aakhundov aakhundov commented Dec 5, 2024

This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the tuned_mm and tuned_addmm lowerings. The persistent+TMA choices are added to the GEMM autotuning candidates when all of the following hold (checked by the use_triton_tma_template helper; see the sketch after this list):

  1. The minimum hardware and Triton version requirements for TMA support are met.

  2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

  3. The config.triton.enable_persistent_tma_matmul option is set to True.
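
A minimal sketch of what such an eligibility check can look like (hypothetical helper and constant names; simplified relative to the actual use_triton_tma_template logic in the PR):

```python
import torch
from torch._inductor import config  # assumes the flag lives here, per the description above

TMA_ALIGNMENT_BYTES = 16  # hypothetical constant: TMA requires 16-byte alignment

def can_use_tma_template(mat_a: torch.Tensor, mat_b: torch.Tensor) -> bool:
    # 1. Hardware gate: TMA needs Hopper (sm90+); the real helper also checks
    #    the Triton version.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    if major < 9:
        return False
    # 2. Input compatibility: contiguous (the real check also accepts
    #    column-major, i.e. transposed, layouts) and 16-byte aligned.
    for t in (mat_a, mat_b):
        if not t.is_contiguous():
            return False
        if t.data_ptr() % TMA_ALIGNMENT_BYTES != 0:
            return False
        if (t.stride(0) * t.element_size()) % TMA_ALIGNMENT_BYTES != 0:
            return False
    # 3. The feature must be opted into explicitly.
    return config.triton.enable_persistent_tma_matmul
```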

Additional notes:

  1. As added in this PR, the TMA use is not compatible with prologue / epilogue fusion. Accordingly, the new Triton template currently supports TMA-based loads of A/B but no prologue fusion, and epilogue fusion but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

  2. The current Triton TMA API (experimental_device_tensormap_create2d) does not support strides. Due to this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous.

  3. The transposed layouts of A and / or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (see the sketch after these notes); this should have no effect on the kernel perf, as the ordering is resolved at Triton compilation time.

  4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in [Pipeliner] Multi-buffer TMA descriptors triton-lang/triton#5290) in the new Triton template, which should allow lifting limitations 2 and 3 above.

  5. The configs for the new Triton template in persistent_mm_kernel_configs are preliminary. We should do more perf exploration and possibly augment the configs in a follow-up.

  6. This PR is rebased onto, and unified with, two related PRs landed previously: Adding lowering to persistent-tma device kernel for _scaled_mm #142045 (some infra unification with the persistent+TMA template for _scaled_mm) and Prologue Fusion #134532 (adds the possibility to disable prologue fusion for selected choices).

  7. The current Triton TMA API only supports 1D and 2D descriptors (even after [Pipeliner] Multi-buffer TMA descriptors triton-lang/triton#5290; see here). For now, this blocks adding a persistent+TMA template for torch.bmm.
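
As an illustration of note 3, the A-tile load inside the kernel picks the TMA block shape with a constexpr flag (simplified from the template code quoted in the review below; the offset variable names here are illustrative, not the exact ones used in the template):

```python
# Inside the kernel body: A_ROW_MAJOR is a tl.constexpr flag passed from the
# host side, so the block-shape choice below is resolved at Triton compile
# time and the branch has no runtime cost.
a = tl._experimental_descriptor_load(
    a_desc_ptr,  # TMA descriptor created on the host / workspace side
    [offs_m, offs_k] if A_ROW_MAJOR else [offs_k, offs_m],
    [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
    A.dtype.element_ty,
)
```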

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Dec 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142101

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3bc5de0 with merge base e0bdae7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

aakhundov added a commit that referenced this pull request Dec 5, 2024
@aakhundov aakhundov marked this pull request as draft December 5, 2024 03:39
@aakhundov aakhundov added the topic: not user facing label Dec 5, 2024
aakhundov added a commit that referenced this pull request Dec 5, 2024
aakhundov added a commit that referenced this pull request Dec 10, 2024
aakhundov added a commit that referenced this pull request Dec 10, 2024
aakhundov added a commit that referenced this pull request Dec 10, 2024
aakhundov added a commit that referenced this pull request Dec 11, 2024
aakhundov added a commit that referenced this pull request Dec 12, 2024
aakhundov added a commit that referenced this pull request Dec 12, 2024
ghstack-source-id: 7f86c28
Pull Request resolved: #142101
aakhundov added a commit that referenced this pull request Dec 13, 2024
ghstack-source-id: 16cd91f
Pull Request resolved: #142101
@aakhundov aakhundov changed the title from "[WIP] Add persistent+TMA version of Triton mm and addmm" to "Add persistent+TMA version of Triton mm and addmm" Dec 13, 2024
@aakhundov aakhundov marked this pull request as ready for review December 13, 2024 18:31
Comment on lines +2648 to +2658
# For prologue fusion we check if the underlying template of the choice
# supports all allowed prologue inputs. If not, we skip this choice in
# the fusion benchmark.
# TODO: Remove this check after all Triton templates support prologue fusion.
# Currently, persistent+TMA Triton template does not due to the TMA-based loads.
if (
not epilogue_fusion
and hasattr(choice, "allowed_prologue_inps")
and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps
):
continue
Contributor Author

cc @eellison: this is to selectively skip choices not supporting prologue fusion (like currently the choices from the persistent+TMA template).

Comment on lines +339 to +349
scaled_persistent_mm_kernel_configs = [
{"config": (128, 128, 64, 3, 8), "cond": True},
{"config": (128, 128, 128, 3, 8), "cond": True},
{"config": (128, 128, 128, 4, 8), "cond": True},
{"config": (128, 128, 128, 4, 4), "cond": True},
{"config": (128, 128, 128, 3, 4), "cond": True},
{"config": (128, 128, 128, 5, 4), "cond": True},
{"config": (128, 128, 128, 5, 8), "cond": True},
{"config": (128, 128, 128, 6, 8), "cond": True},
{"config": (128, 128, 64, 4, 8), "cond": True},
]
Contributor Author

cc @drisspg: separated out these configs used in the _scaled_mm persistent+TMA template-based lowering, as ~half of them OOM on SMEM for 2-byte dtypes. Kept the ones that don't in persistent_mm_kernel_configs.
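
For a rough sense of the SMEM pressure (assuming each config tuple is (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) and counting only the pipelined A/B tiles): a software-pipelined GEMM keeps about num_stages * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N) * itemsize bytes of tiles in shared memory. For (128, 128, 128, 6, 8) with a 2-byte dtype that is 6 * (16384 + 16384) * 2 = 393,216 bytes (~384 KB), well above the ~228 KB of SMEM per SM on H100, whereas the same config with a 1-byte dtype needs ~192 KB and fits.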

w_inverse_scale,
bias,
)
with config.patch({"triton.enable_persistent_tma_matmul": True}):
Contributor

wow, bad mistake on my end thanks for fixing!

# based on triton.ops.matmul
start_pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
Contributor

nit: why not cdiv for these too?
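
For reference, Triton's built-in ceiling division expresses the same computation (not the template's current code, just the equivalent form the reviewer is suggesting):

```python
grid_m = tl.cdiv(M, BLOCK_M)  # same as (M + BLOCK_M - 1) // BLOCK_M
grid_n = tl.cdiv(N, BLOCK_N)
```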

workspace_arg=get_workspace_arg(
kwargs["NUM_SMS"], mat_a.get_device()
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=3,
Contributor

We actually only need 2, if you care to update that as well; my top of stack had the C stores, but that was buggy anyway.

Contributor Author

Thanks for catching this! Helped me find a nasty bug in the new template code, too.

Contributor

@drisspg drisspg left a comment

Looks great! Will let Elias comment on the prologue stuff.

aakhundov added a commit that referenced this pull request Dec 13, 2024
ghstack-source-id: 8008755
Pull Request resolved: #142101
Contributor

@eellison eellison left a comment

looks great!

# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
Contributor

Maybe in a follow-up we can dedup with the scaled version. cc @drisspg

Contributor

@drisspg drisspg Dec 13, 2024

Yeah should/want to do this

[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
A.dtype.element_ty,
)
b = tl._experimental_descriptor_load(
Contributor

I don't see any mask here. I guess it doesn't support K not divisible by the K block?

Contributor Author

TMA does support K not divisible by the K block (given that K * dtype.itemsize % 16 == 0). Masking happens in the HW doing the TMA, with the OOB values set to zero.
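
(For a concrete example of that divisibility rule: with fp16, itemsize 2, K = 72 gives 72 * 2 = 144 bytes, which is divisible by 16, so the K tail is handled by the TMA hardware with zero-filled out-of-bounds values and no explicit mask; K = 100 would give 200 bytes and not qualify.)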

c
for c in choices
if re.search(
config.test_configs.autotune_choice_name_regex,
Contributor

nit: is it worth functools.lru_caching re.compile() and using that?

Contributor Author

I imagine these regexes being used exclusively for testing (as in: not in prod / with real models) and being relatively simple (probably a substring, or a few separated by |). So I'm not sure how much it's worth precompiling them. Happy to add it if you feel it's worth it, though.
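
If precompiling were wanted later, a minimal version of the reviewer's suggestion could look like this (hypothetical helper name, shown only to make the idea concrete):

```python
import functools
import re

@functools.lru_cache(maxsize=None)
def _cached_compile(pattern: str) -> "re.Pattern[str]":
    # Compile each test-config regex once; repeated filtering reuses the cached object.
    return re.compile(pattern)

# Usage inside the choice filter:
#   _cached_compile(config.test_configs.autotune_choice_name_regex).search(...)
```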

@aakhundov
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 14, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x b85dd6fa405d8b7019e83646b62b21f92e27999a returned non-zero exit code 1

Auto-merging test/inductor/test_max_autotune.py
CONFLICT (content): Merge conflict in test/inductor/test_max_autotune.py
Auto-merging torch/_inductor/kernel/mm.py
Auto-merging torch/_inductor/kernel/mm_common.py
Auto-merging torch/_inductor/scheduler.py
Auto-merging torch/_inductor/select_algorithm.py
Auto-merging torch/_inductor/utils.py
error: could not apply b85dd6fa405... Add persistent+TMA version of Triton mm and addmm
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

aakhundov added a commit that referenced this pull request Dec 15, 2024
ghstack-source-id: 2d14d8f
Pull Request resolved: #142101
@aakhundov
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@aakhundov aakhundov self-assigned this Dec 17, 2024
@github-actions github-actions bot deleted the gh/aakhundov/19/head branch January 18, 2025 02:02
)
if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
Contributor

How do you determine if not doing this results in spills/fills?
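
Not an answer from the thread, but one practical way to check (a sketch; it assumes a Triton version whose compiled-kernel handle exposes n_regs / n_spills, the same fields PyTorch's Triton autotuner reads to penalize spilling configs):

```python
# Compile the kernel without launching it, then inspect register pressure.
# `matmul_kernel` and `example_args` are placeholders for the template kernel
# and a representative set of arguments.
compiled = matmul_kernel.warmup(*example_args, grid=(1,))
compiled._init_handles()  # loads the cubin and populates n_regs / n_spills
print(f"registers: {compiled.n_regs}, spill slots: {compiled.n_spills}")
```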

