🚀 The feature, motivation and pitch
Motivation
torch.compile provides the "max-autotune" mode. For CUDA, the inductor backend leverages online benchmark results to select the best-performing kernels from various options, including ATen kernels and template-based kernels implemented with Triton and CUTLASS. These kernels are primarily designed to accelerate GEMM-related operations. However, for CPU, this "max-autotune" mechanism is not yet supported, and only ATen kernels are currently utilized.
This RFC proposes introducing similar template-based code generation support for GEMM-related operations on CPU, implemented in C++ and activated through the "max-autotune" mode of torch.compile. By leveraging Inductor's autotuning mechanism, users can expect performance for GEMM-related operations beyond what the ATen-based implementations provide.
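As a usage sketch (the toy model and shapes here are placeholders), the mode would be enabled the same way it is on CUDA today:

```python
import torch

# Toy model whose Linear layers hit the GEMM paths discussed in this RFC.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).eval()

# "max-autotune" asks Inductor to benchmark candidate implementations
# (ATen kernels vs. generated template kernels) and keep the fastest.
compiled = torch.compile(model, mode="max-autotune")

with torch.no_grad():
    out = compiled(torch.randn(8, 1024))
```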
Approaches
At a high level, the autotuning and template infrastructure from CUDA is mature enough to be adapted for CPU usage. We plan to extend the existing autotuning code to support CPU and to develop the C++ template abstraction by referencing its CUTLASS counterpart. Additionally, CPU-specific challenges such as thread decomposition, data layout arrangement (e.g., weight prepacking), and data blocking at various levels of the memory hierarchy need to be addressed for optimal performance. Based on our previous experience, we employ a two-level abstraction to implement GEMMs: an outer loop that manages thread decomposition and cache blocking, and an inner micro-kernel that handles register blocking and other CPU-architecture-specific optimizations. This approach allows for flexible performance tuning at multiple levels and direct use of low-level CPU hardware acceleration.
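To make the two-level structure concrete, here is a schematic NumPy sketch rather than the actual generated C++; the block sizes and the micro_gemm helper are illustrative placeholders:

```python
import numpy as np

MR, NR = 4, 16            # register-block sizes owned by the micro-kernel
MC, NC, KC = 64, 64, 256  # cache-block sizes owned by the outer loops

def micro_gemm(a, b, c):
    # Inner level: in the generated C++ this is where register blocking
    # and ISA-specific code (e.g., AVX-512 or AMX) would live.
    c += a @ b

def gemm(A, B):
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    # Outer level: cache blocking; in the real template the (mc, nc)
    # iterations would also be decomposed across threads.
    for mc in range(0, M, MC):
        for nc in range(0, N, NC):
            for kc in range(0, K, KC):
                for mr in range(mc, min(mc + MC, M), MR):
                    for nr in range(nc, min(nc + NC, N), NR):
                        micro_gemm(
                            A[mr:mr + MR, kc:kc + KC],
                            B[kc:kc + KC, nr:nr + NR],
                            C[mr:mr + MR, nr:nr + NR],
                        )
    return C
```

The weight prepacking mentioned above fits this picture: B would be re-laid-out ahead of time into the blocked order the micro-kernel consumes, so the innermost loads are contiguous.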
Key Components
- Autotune Infrastructure for CPU: Generalizing and extending BenchmarkRequest with CPU support and adding a Cpp module loader.
- Cpp Template Infrastructure: Introducing template abstractions similar to their CUTLASS counterparts, such as CppTemplate, CppTemplateKernel, and CppTemplateBuffer. The MicroGemm micro-kernel abstraction can be used by the Cpp GEMM templates.
- Micro Kernel Templates: Responsible for register blocking, instruction selection, and other CPU architecture-specific optimizations.
- Cpp Templates: Various GEMM-related Cpp templates (single GEMM, weight-only quantized GEMM, attention, MLP, etc.) responsible for thread decomposition, cache blocking, and the outer-loop scheduling that calls into the micro-kernels; packed GEMM is also supported.
- Epilogue Fusion: Requires support from the Cpp templates, micro-kernel templates, and Cpp kernels; a sketch of the idea follows this list.
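The idea behind epilogue fusion is to apply the pointwise epilogue to each output tile while it is still in cache, instead of in a separate full pass over C. A minimal NumPy sketch (the tile size, bias, and ReLU epilogue are illustrative choices):

```python
import numpy as np

def gemm_with_fused_epilogue(A, B, bias, epilogue=lambda x: np.maximum(x, 0.0)):
    M, N = A.shape[0], B.shape[1]
    C = np.empty((M, N), dtype=A.dtype)
    T = 64  # illustrative tile size
    for m in range(0, M, T):
        for n in range(0, N, T):
            acc = A[m:m + T] @ B[:, n:n + T]  # tile GEMM
            # Fused epilogue: bias add + activation on the hot tile,
            # rather than a second pass over all of C afterwards.
            C[m:m + T, n:n + T] = epilogue(acc + bias[n:n + T])
    return C
```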
Task Breakdowns
- 1. Autotune Infrastructure for CPU ([inductor] autotune benchmark support for cpu #125159)
- 2. Cpp Template Infrastructure ([inductor][cpp] GEMM template (infra and fp32) #124021)
- 3. Micro Kernel Templates
- 3.1 General FP32/BF16/FP16 MicroGemm based on ATen VEC ([inductor][cpp] GEMM template (infra and fp32) #124021 etc.)
- 3.2 BF16 AMX MicroGemm for x86 ([inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion #126068 etc.)
- 3.3 FP16 AMX MicroGemm for x86
- 3.4 INT8 AMX MicroGemm for x86 ([Inductor][CPP] Enable Quantized Linear with AMX MicroGEMM #129220)
- 3.5 INT8 Weight-quantized MicroGemm for x86 (Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue #131887)
- 3.6 INT4 Weight-quantized MicroGemm for x86
- 3.7 MicroGemms for ARM
- 4. Cpp Templates
- 4.1 Single GEMM, packed ([inductor][cpp] GEMM template (infra and fp32) #124021, [RELAND][inductor][cpp] bf16/fp16 gemm template computed with fp32 #128472, [inductor][cpp] support bf16/fp16 gemm template epilogue fusion #126545, [inductor][cpp] epilogue support for gemm template #126019, [Inductor][CPP] Enable Quantized Linear GEMM Template with FP32 output #128825, [Inductor][CPP] Enable Quantized Linear GEMM Template with INT8 output and Unary Post Op #129048, [Inductor][CPP] Enable Quantized Linear GEMM Template with Binary Fusion #129103, [Inductor][CPP] Enable Quantized Linear with AMX MicroGEMM #129220, [inductor][cpp][gemm] optimize arbitrary N in packed gemm template #130690)
- 4.2 Single GEMM, unpacked
- 4.3 BMM ([inductor][cpp] Add BMM kernel template for autotuning #129772)
- 4.4 WOQ GEMM (Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue #131887 etc.)
- 4.5 SDPA
- 4.6 MLP
- 5. Epilogue Fusion ([inductor][cpp] epilogue support for gemm template #126019, [inductor][cpp] support bf16/fp16 gemm template epilogue fusion #126545 etc.)
- 6. Performance Tuning (ongoing work)
- 6.1 Thread blocking optimization ([inductor][cpp][gemm] improve thread blocking heuristics #131024, [inductor][cpp][gemm] support k slicing for static shapes #130821 etc.; see the sketch after this list)
- 6.2 Cache blocking optimization ([inductor] [cpp] improve cache blocking with CPU info #129348, [inductor][cpp][gemm] improve large bs perf with better cache blocking #132729, [inductor] [cpp] use non-temporal tile load for A #129455 etc.)
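For items 6.1 and 6.2, the following is a hypothetical illustration of the kind of heuristic involved, not the actual implementation: factor the thread count into an (M, N) grid, and when there is not enough M/N parallelism, spend the spare factor on K slicing (which then requires a reduction across the K slices):

```python
def decompose_threads(num_threads, m_blocks, n_blocks):
    """Pick a (t_m, t_n, t_k) thread grid. Illustrative heuristic only."""
    best, best_util = (num_threads, 1, 1), -1.0
    for t_m in range(1, num_threads + 1):
        if num_threads % t_m:
            continue
        for t_n in range(1, num_threads // t_m + 1):
            if (num_threads // t_m) % t_n:
                continue
            t_k = num_threads // (t_m * t_n)
            # Fraction of the grid that gets at least one block, with a
            # mild penalty for k-slicing since it needs a reduction step.
            util = (min(m_blocks / t_m, 1.0) * min(n_blocks / t_n, 1.0)
                    / (1.0 + 0.1 * (t_k - 1)))
            if util > best_util:
                best_util, best = util, (t_m, t_n, t_k)
    return best

# E.g., with 56 threads but only 4x2 cache blocks, the heuristic keeps a
# 4x2 grid over M/N and slices K seven ways: prints (4, 2, 7).
print(decompose_threads(56, 4, 2))
```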
Alternatives
No response
Additional context
No response
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire