FlexGEMM is a high-performance, Triton-powered GEMM backend designed for 3D sparse convolutions.
It implements Explicit, Implicit, and Masked Implicit GEMM algorithm variants, with optional Split-K parallelism for sparse GEMM. FlexGEMM delivers state-of-the-art performance for Submanifold Convolution and voxel-based neural networks, consistently outperforming existing solutions.
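For orientation, the sketch below shows how a variant is selected at runtime via flex_gemm.ops.spconv.set_algorithm. It is illustrative only: it assumes the Algorithm enum can be iterated like a standard Python enum, and only the MASKED_IMPLICIT_GEMM_SPLITK member is confirmed by the Quick Start example further down.

import flex_gemm

# Illustrative sketch: list the variants exposed by the Algorithm enum
# (assumes it behaves like a standard Python enum).
for alg in flex_gemm.ops.spconv.Algorithm:
    print(alg.name)

# Switch the backend globally for subsequent sparse convolutions
flex_gemm.ops.spconv.set_algorithm(
    flex_gemm.ops.spconv.Algorithm.MASKED_IMPLICIT_GEMM_SPLITK
)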
- Deep Dive: Read the technical blog at JeffreyXiang's Blog.
- Real-world Demo: See FlexGEMM in action in the TRELLIS.2 project.
- Triton-First Architecture: Built entirely on Triton, ensuring high-performance kernel execution and cross-platform compatibility.
- Sparse-Optimized: Specifically tailored for 3D sparse tensors, efficiently handling highly irregular sparsity patterns.
- Blazing Fast: Consistently outperforms standard sparse convolution libraries such as spconv and torchsparse in training throughput.
- PyTorch ≥ 2.4.0
- Triton ≥ 3.2.0
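To check whether an existing environment already meets these minimums, the following illustrative snippet prints the installed versions (it is not part of FlexGEMM itself):

import torch
import triton

# Compare against the minimum versions listed above
print("PyTorch:", torch.__version__)  # requires >= 2.4.0
print("Triton: ", triton.__version__)  # requires >= 3.2.0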
git clone https://github.com/JeffreyXiang/FlexGEMM.git
cd FlexGEMM
pip install . --no-build-isolation

Here is a minimal example demonstrating how to perform a sparse submanifold convolution using FlexGEMM:
import torch
import flex_gemm
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from tests.spconv_fwd import sphere_coords
# 1. Prepare Sparse Voxel Data
# Generate a sparse voxel shell
feats, coords, shape = sphere_coords(256, 256, dtype=torch.float16, device='cuda')
# 2. Define Weights and Bias
Ci, Co = 256, 256
Ks = 3
weight = torch.randn(Co, Ks, Ks, Ks, Ci, dtype=torch.float16, device='cuda', requires_grad=True)
bias = torch.randn(Co, dtype=torch.float16, device='cuda', requires_grad=True)
# 3. Configure Algorithm
# Example: Using Masked Implicit GEMM with Split-K optimization
flex_gemm.ops.spconv.set_algorithm(
flex_gemm.ops.spconv.Algorithm.MASKED_IMPLICIT_GEMM_SPLITK
)
# 4. Forward Pass
out_feats, neighbor_cache = sparse_submanifold_conv3d(
feats, coords, shape,
weight, bias,
)
# 5. Backward Pass
out_feats.sum().backward()

FlexGEMM demonstrates significant speed improvements over existing baselines.
Test Environment:
- GPU: NVIDIA A100 80GB PCIe
- Software: PyTorch 2.4.1, CUDA 12.0, Triton 3.2.0
Note: FlexGEMM achieves ~2× acceleration compared to previous state-of-the-art methods with efficient data formats such as FP16 and TF32.
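To get a rough single-layer measurement on your own hardware, you can time the Quick Start call with CUDA events and record peak memory. The sketch below is illustrative only: it reuses the sphere_coords helper from the tests, the call signature from the Quick Start, and standard PyTorch utilities, and the numbers it prints will depend on your GPU, driver, and Triton version.

import torch
from flex_gemm.ops.spconv import sparse_submanifold_conv3d
from tests.spconv_fwd import sphere_coords

# Same setup as the Quick Start: a sparse voxel shell with 256 channels
feats, coords, shape = sphere_coords(256, 256, dtype=torch.float16, device='cuda')
weight = torch.randn(256, 3, 3, 3, 256, dtype=torch.float16, device='cuda')
bias = torch.randn(256, dtype=torch.float16, device='cuda')

torch.cuda.reset_peak_memory_stats()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Warm-up (the first call may include Triton compilation/autotuning)
out, cache = sparse_submanifold_conv3d(feats, coords, shape, weight, bias)
torch.cuda.synchronize()

start.record()
for _ in range(10):
    out, cache = sparse_submanifold_conv3d(feats, coords, shape, weight, bias)
end.record()
torch.cuda.synchronize()

print(f"avg forward: {start.elapsed_time(end) / 10:.3f} ms")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")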
- SOTA Speed: Consistently outperforms spconv, torchsparse, and fvdb.
- Scalability: Robust performance across channel widths (C=64 to C=1024) and resolutions (RES=8 to RES=1024).
- Memory Efficient: Delivers higher throughput without increasing GPU memory overhead.
- Application Ready: Ideal for high-resolution voxelized point clouds, submanifold convolutions, and large-scale 3D networks.
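For use inside a larger network, one straightforward pattern is to wrap the op in a small torch.nn.Module that owns its weight and bias. The class below is an illustrative sketch, not a layer shipped by FlexGEMM; it relies only on the call signature shown in the Quick Start, and it passes coordinates through unchanged because submanifold convolution preserves the active voxel set.

import torch
import torch.nn as nn
from flex_gemm.ops.spconv import sparse_submanifold_conv3d

class SubMConv3d(nn.Module):
    # Illustrative wrapper, not part of FlexGEMM itself.
    def __init__(self, in_channels, out_channels, kernel_size=3, dtype=torch.float16):
        super().__init__()
        # Weight layout follows the Quick Start: (Co, Ks, Ks, Ks, Ci)
        self.weight = nn.Parameter(0.02 * torch.randn(
            out_channels, kernel_size, kernel_size, kernel_size, in_channels, dtype=dtype))
        self.bias = nn.Parameter(torch.zeros(out_channels, dtype=dtype))

    def forward(self, feats, coords, shape):
        out_feats, _ = sparse_submanifold_conv3d(feats, coords, shape, self.weight, self.bias)
        # Submanifold convolution preserves the active set, so coords/shape pass through
        return out_feats, coords, shape

Instances of such a wrapper can then be stacked like ordinary PyTorch layers, threading feats, coords, and shape through the network.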
We welcome contributions to make FlexGEMM faster and more robust!
- Report Bugs: Open an issue describing the bug and how to reproduce it.
- Suggest Features: Have an idea for a new algorithm or optimization? Let us know!
- Submit Pull Requests:
  - Fork the repository and create your branch from main.
  - Ensure your code follows the project's style.
  - Run the tests in the tests/ directory to ensure no regressions.
  - Open a Pull Request with a detailed description.
We appreciate all contributors who help improve this project!
This project is released under the MIT License.


