Description
🚀 The feature, motivation and pitch
Motivation
Cutlass is an efficient template library for compute-heavy GPU operations such as GEMM, convolution, and others, with good support for H100. Several important kernels (e.g., FlashAttention-2 and xFormers attention) are also built on top of Cutlass. It would be a good complement to Triton for generating Inductor fused kernels for compute-heavy operations.
Proposal
The proposal is to add Cutlass as an alternative backend of Inductor, as demonstrated in the prototype PR (#106607). As shown in the PR:
- In `torch/_inductor/config.py`, "Cutlass" can be configured as one of the Inductor max-autotune backends through `max_autotune_gemm_backends`. (In the future, we may want to extend it to cover cases beyond GEMM.) This option can be set via the `options` argument of `torch.compile()` (https://pytorch.org/docs/stable/generated/torch.compile.html#torch-compile).
- Once "Cutlass" is added as a candidate, max-autotune will also tune and select Cutlass kernels alongside the Aten and Triton kernels.
- We'll also utilize the Cutlass epilogue visitor to support flexible GEMM and epilogue fusions in later PRs. More features will come in the future.
Release / Dependency
- Pytorch release: to release Cutlass together with Pytorch properly, we need a way to pack Cutlass into the Pytorch package distribution. This includes both the C++ header files (third_party/cutlass/include) and the Python scripts and modules (third_party/cutlass/tools/library/scripts).
- NVCC dependency: the prototype implementation relies on NVCC. We need to figure out how to switch to NVRTC. The Cutlass team has a proposal to use NVCC for pre-processing at the Pytorch packaging stage, so that end users only depend on NVRTC. This needs to be discussed in detail.
- Meta internal deployment: we need to figure out how to properly deploy the Cutlass dependency in Meta's internal environment.
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @wconstab @Xia-Weiwen @ngimel @atalman @malfet @albanD @ptrblck @seemethere @jansel @gottbrath
cc @jackkosaian @mnicely @hwu36
Alternatives
No response
Additional context
No response