
ThunderKittens

ThunderKittens: Tile primitives for speedy kernels

ThunderKittens is a framework to make it easy to write fast deep learning kernels in CUDA. It is built around three key principles:

  1. Simplicity. ThunderKittens is stupidly simple to write.
  2. Extensibility. ThunderKittens is natively embedded into CUDA, so that if you need more than ThunderKittens can offer, it won’t get in your way of building it yourself.
  3. Speed. Kernels written in ThunderKittens should be at least as fast as those written from scratch -- especially because ThunderKittens can do things the “right” way under the hood. We think our FlashAttention-3 implementation speaks for this point.
(Figure: Flash Attention 3, but with kittens!)

ThunderKittens began as an internal art project and is maintained by graduate students at the Hazy Research Lab. Nonetheless, many AI companies use it for production-scale training and inference (e.g., Together AI, Jump Trading, and Cursor).

ThunderKittens is built for NVIDIA GPUs. For AMD GPUs, check out HipKittens.

Recent Updates

Jan 11, 2026: ThunderKittens 2.0 is out!

  • This release brings full support for Blackwell GPUs along with MXFP8 and NVFP4 precision, and merges major contributions from across the industry.
  • The repository structure has changed. We no longer support the repo as a Python package (i.e., a top-level setup.py). Kernels under the /kernels directory must now be compiled individually. Makefiles, tests, and benchmarks reside alongside their corresponding kernel source files.
  • We no longer actively support Ampere GPUs. While ThunderKittens should still work on Ampere, we do not plan to bring further support to it.

Overview

ThunderKittens is built from the hardware up; we do what the silicon tells us. And modern GPUs tell us that they want to work with fairly small tiles of data. A GPU is not really a 1000x1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16x16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16x16 values.

ThunderKittens makes a few tricky things easy that enable high utilization on modern hardware.

  1. Tensor cores. ThunderKittens can call fast tensor core functions, including asynchronous WGMMA calls on H100 GPUs and TCGEN05 calls on B200 GPUs.
  2. Shared Memory. I got ninety-nine problems but a bank conflict ain’t one.
  3. Loads and stores. Hide latencies with asynchronous copies and address generation with TMA.
  4. Distributed Shared Memory. L2 is so last year.
  5. Worker overlapping. Use our Load-Store-Compute-Finish template to overlap work and I/O.
  6. GPU networking. ThunderKittens lets you transfer data over NVLink and utilize NVSwitch acceleration for fast multi-GPU operations.

Example: A Simple Matrix Multiplication Kernel

Here’s what a simple matrix multiplication kernel for an H100 looks like when written in ThunderKittens.

#include "kittens.cuh"
#include "prototype.cuh"

using namespace kittens;
using namespace kittens::prototype;
using namespace kittens::prototype::lcf;

template<int M_BLOCK, int N_BLOCK>
struct matmul_layout {
    using  base_tile      = st_bf<64, 64>;
    using  global_layout  = gl<bf16, 1, 1, -1, -1, base_tile>;
    struct globals        { global_layout A, B, C; };
    struct input_block    { base_tile a[M_BLOCK], b[N_BLOCK]; };
    struct finish_block   { base_tile c[M_BLOCK][N_BLOCK]; };
    struct common_state   { int2 coord; };
    struct consumer_state { rt_fl<16, N_BLOCK*base_tile::cols> accum; };
};
template<int _M_BLOCK=2, int _N_BLOCK=4, int _SUPER_M=12>
struct matmul_template {
    static constexpr int M_BLOCK = _M_BLOCK, N_BLOCK = _N_BLOCK, SUPER_M = _SUPER_M;
    using layout    = matmul_layout<M_BLOCK, N_BLOCK>;
    using wide_tile = st_bf<64, 64*N_BLOCK>;
    static constexpr int NUM_CONSUMER_WARPS=M_BLOCK*4, INPUT_PIPE_STAGES=4, PRODUCER_BARRIER_ARRIVALS=1;
    // Helper functions
    template<bool PERSISTENT_GRID=true> __host__ static inline dim3 grid(int M, int N, int K) {
        return dim3(PERSISTENT_GRID ? 132 : M*N/(M_BLOCK*N_BLOCK*layout::base_tile::num_elements));
    }
    // ThunderKittens template functions
    __device__ static inline void common_setup(common_setup_args<layout> args) {
        int Rblocks = args.globals.C.rows() / (M_BLOCK*64), Cblocks = args.globals.C.cols() / (N_BLOCK*64);
        int super_rows = (Rblocks/SUPER_M)*SUPER_M,
            final_rows = Rblocks - super_rows,
            super_repeat = SUPER_M*Cblocks;
        int task_id = args.task_iter*gridDim.x + blockIdx.x;
        if (task_id < super_rows * Cblocks)
            args.common.coord = { SUPER_M*(task_id/super_repeat) + task_id%SUPER_M,
                           (task_id%super_repeat)/SUPER_M };
        else if (task_id < Rblocks*Cblocks) {
            int remainder_id = task_id - super_rows*Cblocks;
            args.common.coord = { super_rows + (remainder_id%final_rows), remainder_id/final_rows };
        }
        else { // Id is too high, no more work to do
            args.num_iters = -1;
            return;
        }
        args.num_iters = args.globals.A.cols()/64;
        int id = warpgroup::groupid() == NUM_CONSUMER_WARPS/4 ? 0 : warpgroup::groupid(); // producer sets as 0
        args.common.coord = { args.common.coord.x*M_BLOCK + id, args.common.coord.y*N_BLOCK };
    }
    struct producer {
        __device__ static void setup(producer_setup_args<layout> args) {
            warpgroup::decrease_registers<40>(); // decrease registers for producers
        }
        __device__ static void load(producer_load_args<layout> args) {
            if (warpgroup::laneid() == 0) {
                tma::expect(args.inputs_arrived, args.input);
                for(int i = 0; i < M_BLOCK; i++)
                    tma::load_async(args.input.a[i], args.globals.A,
                                    {args.common.coord.x+i, args.iter}, args.inputs_arrived);
                for(int i = 0; i < N_BLOCK; i++)
                    tma::load_async(args.input.b[i], args.globals.B,
                                    {args.iter, args.common.coord.y+i}, args.inputs_arrived);
            }
        }
    };
    struct consumer {
        __device__ static void setup(consumer_setup_args<layout> args) {
            warpgroup::increase_registers<232>(); // increase registers for consumers
            kittens::warp::zero(args.state.accum);
        }
        __device__ static void compute(consumer_compute_args<layout> args) {
            warpgroup::mma_AB(
                args.state.accum, // dest registers
                args.input.a[warpgroup::groupid()], // A matrix
                reinterpret_cast<wide_tile&>(args.input.b) // B matrix
            );
            warpgroup::mma_async_wait();
            if (warp::laneid() == 0) arrive(args.inputs_finished);
        }
        __device__ static void finish(consumer_finish_args<layout> args) {
            warpgroup::store(reinterpret_cast<wide_tile&>(args.finish.c[warpgroup::groupid()]), args.state.accum);
            warpgroup::sync(warpgroup::groupid()+4);
            if (warpgroup::laneid() == 0) for(int i = 0; i < N_BLOCK; i++) {
                tma::store_async(args.globals.C, args.finish.c[warpgroup::groupid()][i],
                                             {args.common.coord.x, args.common.coord.y+i});
                tma::store_async_read_wait(); // wait until the store has finished before reusing the finish memory
            }
            kittens::warp::zero(args.state.accum);
            if (warp::laneid() == 0) arrive(args.finish_finished);
        }
    };
};

Altogether, this is less than 100 lines of code, and achieves about 855 TFLOPs on an H100 (86% of theoretical max). We’ll go through some of these primitives more carefully later, in the ThunderKittens Manual section.

Installation

ThunderKittens is a header-only library and requires no installation: just clone the repo and include kittens.cuh. Easy money.
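To see the whole flow end to end, here is a minimal sketch. The file name, kernel name, and body are ours, and we assume a kittens::one fill operation alongside the mul call shown later in this README:

// hello_tk.cu -- a minimal sanity-check kernel (file and kernel names are ours)
#include "kittens.cuh"
using namespace kittens;

__global__ void hello_tk() {
    rt_fl<32, 64> tile;      // a 32x64 float register tile owned by this warp
    one(tile);               // fill with 1.0f (assumed fill op)
    mul(tile, tile, tile);   // element-wise multiply: destination comes first
}

int main() {
    hello_tk<<<1, 128>>>();  // one block of four warps
    cudaDeviceSynchronize();
    return 0;
}

Compile with something like nvcc -std=c++20 -arch=sm_90a -I<wherever kittens.cuh lives in your clone> hello_tk.cu; the exact flags depend on your GPU and where you cloned the repo.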

Hardware requirements

  • ThunderKittens is mainly built and tested for Hopper and Blackwell GPUs.
  • We no longer actively support Ampere GPUs. However, contributions are welcomed!

Build requirements

ThunderKittens does use a bunch of modern stuff, so it has fairly aggressive requirements.

  • CUDA 12.8+. Anything after CUDA 12.1 will probably work, but you'll likely end up with serialized wgmma pipelines on H100s due to a bug in those earlier versions of CUDA. We do our dev work on CUDA 12.8-13.1, because we want our kittens to play in the nicest, most modern environment possible. Make sure you run the following to set up your CUDA environment properly:

    export CUDA_HOME=/usr/local/cuda-<YOUR-CUDA-VERSION> # e.g., cuda-12.8
    export PATH=${CUDA_HOME}/bin:${PATH} 
    export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
  • C++20. TK runs on concepts. If you get weird compilation errors, chances are your gcc is out of date. Update your compiler with:

    # Install gcc/g++ 11
    sudo apt update
    sudo apt install gcc-11 g++-11
    
    # Make gcc-11 the default compiler
    sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11
    
    # Alternatively, use clang
    sudo apt update
    sudo apt install clang-11

Sometimes, there's a libc10.so error, which you can fix with:

# Note the <PRINTED_PATH> printed by:
python -c "import torch; print(torch.__file__)"

# Then add the corresponding lib directory to your library path:
export LD_LIBRARY_PATH=<PRINTED_PATH>/lib:$LD_LIBRARY_PATH

ThunderKittens Manual

ThunderKittens is actually a pretty small library, in terms of what it gives you.

  • Data types: (Register + shared) * (tiles + vectors), all parameterized by layout, type, and size.
  • Operations for manipulating these objects.

Therefore, the best way to learn ThunderKittens is to start reading kernels and running them yourself! We have a step-by-step, educational kernel series on matrix multiplication under kernels/gemm/educational_h100.

Once you get used to its APIs, there are still a few sharp edges that you might encounter if you don’t know what’s going on under the hood. So, we do recommend giving this manual a good read before sitting down to write a serious kernel -- it’s not too long, we promise!

NVIDIA’s Programming Model

To understand ThunderKittens, it will help to begin by reviewing a bit of how NVIDIA’s programming model works, as NVIDIA provides a few different “scopes” to think about when writing parallel code.

  1. Thread: this is the level of doing work on an individual bit of data, like a floating point multiplication. A thread has up to 256 32-bit registers it can access every cycle.
  2. Warp: 32 threads make a warp. This is the level at which instructions are issued by the hardware. It’s also the base (and default) scope from which ThunderKittens operates; most ThunderKittens programming happens here.
  3. Warpgroup: 4 warps make a warpgroup. This is the level from which asynchronous warpgroup matrix multiply-accumulate instructions are issued. (We really wish we could ignore this level, but you unfortunately need it for the H100.) Correspondingly, many matrix multiply and memory operations are supported at the warpgroup level.
  4. Block: N warps make a block, which is the level that shares “shared memory” in the CUDA programming model. In ThunderKittens, N is often 8.
  5. Grid: M blocks make a grid, where M should be equal to (or slightly less than) a multiple of the number of SMs on the GPU to avoid tail effects. ThunderKittens does not touch the grid scope except through helping initialize TMA descriptors.

“Register” objects exist at the level of warps; their contents are split amongst the threads of the warp. Register objects include:

  • Register tiles, declared as the kittens::rt struct in src/register_tile/rt.cuh. Kittens provides a few useful wrappers -- for example, a 32-row, 16-column, row-layout bfloat16 register tile can be declared as kittens::rt_bf<32,16>; the row layout is implicit by default.
  • Register vectors, which are associated with register tiles. They come in three flavors: naive, aligned, and orthogonal. What's going on under the hood is a bit too complicated for a README; what you need to know is that the naive layout is for when you expect to do lots of compute on vectors (like a layernorm). Otherwise, instantiate column or row vectors depending on how you want to interact with a tile, and let TK take care of the layout for you. Column vectors are used to reduce or map across tile rows (a single column of the tile), and row vectors reduce and map across tile columns (a single row of the tile). For example, to hold the sum of the rows of the tile declared above, we would create a kittens::rt_bf<32,16>::col_vec; (see the sketch below).

In contrast, “Shared” objects exist at the level of the block, and sit only in shared memory.
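Here is a sketch of these declarations in one place. The kernel name is ours, and we assume fill and reduction ops named one and row_sum; imagine a single-warp launch:

// A sketch of tile/vector declarations (kernel name is ours)
__global__ void vectors_sketch() {
    kittens::rt_bf<32, 16> tile;               // register tile; row layout is the default
    kittens::rt_bf<32, 16>::col_vec row_sums;  // one element per row of the tile
    kittens::one(tile);                        // assumed fill op
    kittens::row_sum(row_sums, tile);          // assumed reduction: sum across each row
    __shared__ kittens::st_bf<32, 16> s;       // a shared tile, visible block-wide
    kittens::store(s, tile);                   // warp-level copy from registers to shared memory
}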

All ThunderKittens functions follow a common signature. Much like an assembly language (ThunderKittens' origin comes from thinking about an idealized tile-oriented RISC instruction set), the destination of every function is the first operand, and the source operands are passed sequentially afterwards.

For example, if we have three 32 row, 64 col floating point register tiles: kittens::rt_fl<32,64> a, b, c;, we can element-wise multiply a and b and store the result in c with the following call: kittens::mul(c, a, b);.

Similarly, if we want to then store the result into a half-precision shared tile __shared__ kittens::st_hf<32, 64> s;, we write the function analogously: kittens::store(s, c);.
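Putting those two calls together, a kernel body might look like the following sketch (the kernel name and the assumed kittens::one fill op are ours; imagine a single-warp launch):

__global__ void mul_store_sketch() {
    kittens::rt_fl<32, 64> a, b, c;
    kittens::one(a);                      // assumed fill op, for illustration
    kittens::one(b);
    kittens::mul(c, a, b);                // destination first: c = a * b, element-wise
    __shared__ kittens::st_hf<32, 64> s;
    kittens::store(s, c);                 // float registers -> half-precision shared tile
}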

Typing

ThunderKittens tries hard to protect you from yourself. In particular, ThunderKittens wants to know layouts of objects at compile-time and will make sure they’re compatible before letting you do operations. This is important because there are subtleties to the allowable layouts for certain operations, and without static checks it is very easy to get painful silent failures. For example, a normal matrix multiply requires the B operand to be in a column layout, whereas an outer dot product requires the B operand to be in a row layout.

If you are being told an operation that you think exists doesn't exist, double-check your layouts -- this is the most common error. Only then report a bug :)
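As an illustration, the sketch below shows how a layout mismatch might surface. Note that the ducks::rt_layout::col template parameter and the four-operand warp-level mma_AB signature are our assumptions about the API, so treat the exact spellings as illustrative:

__global__ void layouts_sketch() {
    kittens::rt_bf<16, 16> a;                                  // row layout by default
    kittens::rt_bf<16, 16, kittens::ducks::rt_layout::col> b;  // column layout (assumed spelling)
    kittens::rt_fl<16, 16> c;
    kittens::zero(c);
    kittens::mma_AB(c, a, b, c);    // compiles: the B operand is column-layout
    // kittens::mma_AB(c, a, a, c); // rejected at compile time: B here is row-layout
}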

Scopes

By default, ThunderKittens operations exist at the warp-level. In other words, each function expects to be called by only a single warp, and that single warp will do all of the work of the function. If multiple warps are assigned to the same work, undefined behavior will result. (And if the operation involves memory movement, it is likely to be completely catastrophic.) In general, you should expect your programming pattern to involve instantiating a warpid at the beginning of the kernel with kittens::warpid(), and assigning tasks to data based on that id.
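That pattern looks roughly like the following sketch (the kernel name and work split are ours):

__global__ void per_warp_work() {
    int warp = kittens::warpid();    // index of this warp within the block
    kittens::rt_fl<16, 16> tile;
    kittens::zero(tile);
    // ...then offset every global-memory index by `warp`, so that no two
    // warps ever touch the same tile.
}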

However, not all ThunderKittens functions operate at the warp level. Many important operations, particularly WGMMA instructions, require collaborative groups of warps. These operations exist in the templated kittens::group<collaborative size>. For example, wgmma instructions are available through kittens::group<4>::mma_AB (or kittens::warpgroup::mma_AB, which is an alias). Groups of warps can also collaboratively load shared memory or do reductions in shared memory; a condensed sketch follows below.
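Condensing the warpgroup-scope calls from the matmul example above (the kernel name is ours; launch with a 128-thread block, i.e., one warpgroup):

__global__ void warpgroup_sketch() {
    using namespace kittens;
    __shared__ st_bf<64, 64> a_s, b_s;   // operands staged in shared memory
    rt_fl<16, 64> accum;                 // each of the 4 warps holds a 16-row slice of the result
    warp::zero(accum);
    // All four warps of the warpgroup must reach this call together:
    warpgroup::mma_AB(accum, a_s, b_s);  // asynchronous WGMMA: accum += a_s @ b_s
    warpgroup::mma_async_wait();         // block until the async MMA lands
}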

Other Restrictions

Most operations in ThunderKittens are purely functional. However, some operations do have special restrictions; ThunderKittens tries to warn you by giving them names that stand out. For example, a register tile transpose needs separable arguments: if it is given the same underlying registers as both source and destination, it will silently fail. Consequently, it is named transpose_sep.
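For example, a sketch with shapes and names chosen by us (and an assumed kittens::one fill op):

__global__ void transpose_sketch() {
    kittens::rt_fl<32, 16> src;
    kittens::rt_fl<16, 32> dst;        // distinct registers from src, as required
    kittens::one(src);                 // assumed fill op, for illustration
    kittens::transpose_sep(dst, src);  // destination first, and separate from the source
}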

Onboarding document

We have a slightly outdated and incomplete onboarding document. Please contribute to this if you've run into issues and feel the broader community can benefit from explanations.

Pre-implemented Kernels

We've provided a number of ThunderKittens kernels in the kernels/ folder, which can be easily called from your PyTorch code. To use these kernels:

  1. Make sure the currently activated Python environment has PyTorch 2.8+ and PyBind11 installed, and that your PyTorch build matches your CUDA version (follow the official instructions from PyTorch).

  2. (Optional) Set environment variables. Our build system sets these for you, but doing so on every build is quite slow, so we recommend setting them yourself first.

    # Make sure the Python environment you want to use is active and is called by `python3`.
    export PYTHON_VERSION=$(python3 -c "import sysconfig; print(sysconfig.get_config_var('LDVERSION'))")
    export PYTHON_INCLUDES=$(python3 -c "import sysconfig; print('-I', sysconfig.get_path('include'), sep='')")
    export PYBIND_INCLUDES=$(python3 -m pybind11 --includes)
    export PYTORCH_INCLUDES=$(python3 -c "from torch.utils.cpp_extension import include_paths; print(' '.join(['-I' + p for p in include_paths()]))")
    export PYTHON_LIBDIR=$(python3 -c "import sysconfig; print('-L', sysconfig.get_config_var('LIBDIR'), sep='')")
    export PYTORCH_LIBDIR=$(python3 -c "from torch.utils.cpp_extension import library_paths; print(' '.join(['-L' + p for p in library_paths()]))")
  3. cd into the kernel directory you are interested in (e.g., kernels/gemm/bf16_h100).

  4. Open the Makefile and adjust the configuration to your needs. The options vary per kernel, and most work out of the box.

  5. Build:

    make
  6. Run:

    make run

The correctness tests and benchmarks for these kernels are located alongside their source files. Note that the top-level tests/ directory is unrelated; it contains tests only for the ThunderKittens primitives.

We intentionally keep each kernel self-contained rather than using a shared harness or setup, to make it easy for anyone to add new kernels. For production environments, we recommend wrapping the kernels into your own Python package.
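As a rough sketch of that wrapping step, a PyBind11 binding might look like the following; launch_my_kernel and the exposed name my_kernel are hypothetical, and we assume you link against the object file your kernel's Makefile produces:

// bindings.cpp -- hypothetical PyBind11 wrapper around a TK kernel launcher
#include <torch/extension.h>

// Assumed to be defined in the kernel's compiled translation unit:
void launch_my_kernel(torch::Tensor a, torch::Tensor b, torch::Tensor c);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("my_kernel", &launch_my_kernel,
          "Launch the ThunderKittens kernel on CUDA tensors (hypothetical)");
}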

Demos


We've included a set of starter demos in the demos/ folder, showing how to use TK kernels for training and LLM inference (Qwens, Llamas, LoLCATS LLMs, etc.)!

We're also excited to feature any demos you build -- please link them in a PR!

General setup

Several of these demos are set up to use large 8B models from Hugging Face. To set up, log in:

huggingface-cli login

Set the directory to which you want the models downloaded in the _model_config.yaml file in the demos/configs/ directory.

Attention

Attention powers a large number of current LLMs. TK includes forwards / prefill and backwards kernels. We include causal and non-causal variants.

We include LLM inference integrations:

cd llama_demo/
bash demo_8b.sh

And enter your prompt, e.g., "The capital of America is"

LoLCATS

LoLCATS is a recent state-of-the-art method for converting quadratic attention Transformer LLMs to linear attention LLMs. TK includes a forwards / prefill kernel.

We include:

cd lolcats_demo/
bash demo_8b.sh

And enter your prompt, e.g., "The capital of America is"

Based

Based is a linear attention architecture that combines short sliding window attentions with large-state-size linear attentions. TK includes a forwards / prefill kernel.

Additional installs:

pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

We include:

  • Based 1.3B with TK on a series of recall-intensive in-context learning tasks.

Your Demos!

If you use TK to build any demos, please reach out / make a PR! We'd love to feature it here!!

Tests

ThunderKittens has a fairly comprehensive unit testing suite. Simply run make -j in the tests/ folder. Be warned: this may nuke your computer for a minute or two while it compiles thousands of kernels.

Compilation Options

The tests/Makefile provides several options to customize the tests:

  • GPU_TARGET: Set to either H100 or B200 to specify the target GPU architecture (default: H100).
  • COMP_LEVEL: Set the compiler optimization level. Available options are fast, debug, and profile (default: fast).
  • TEST_INTENSITY: Set the level of test intensity. Higher levels compile more tests but take longer. Available options are 1, 2, 3, and 4 (default: 2).
  • TEST_ALL: Compile and run all available tests. You can also specify individual test sections or tests using flags like -DTEST_WARP_MEMORY or -DTEST_WARP_MEMORY_VEC_DSMEM.

Running the Tests

After successful compilation, run the tests using:

mkdir outputs
./unit_tests printout

This will execute the compiled unit tests and dump results of any failed tests to the outputs/ folder. As a quick note, it is expected for mma tests to occasionally fail. Careful inspection of the output will usually show just a single element differing by a small amount, which we think is due to how floating-point arithmetic is implemented within the tensor cores.

Cleaning the Build

To clean the build directory and remove the compiled binary, run:

make clean

Learn more and get involved!

Learn more about ThunderKittens and how GPUs work by checking out our blog posts, and explore the rest of the Kittens Cinematic Universe. Please check out our papers for even more details!

Finally, join us on Discord to get involved: the ThunderKittens channel @ GPU Mode Discord! Here is the invite link to GPU Mode: https://discord.gg/gpumode

License

This project is licensed under the terms of the MIT license.
