Commit ca8f09d

Add custom indexer k cache scatter op
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 222bc91 commit ca8f09d

6 files changed: +464 -19 lines

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@

/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3,
    cudaStream_t stream = 0);

}

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@

/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "IndexerKCacheScatter.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

namespace
{
/**
 * Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3],
 * find the actual memory offset within the given k cache pool using the strides.
 */
__device__ __forceinline__ int64_t flatIndexToMemoryOffset(
    int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3)
{
    // Unravel from innermost to outermost dimension
    int32_t i3 = flat_idx % d3;
    flat_idx /= d3;

    int32_t i2 = flat_idx % d2;
    flat_idx /= d2;

    int32_t i1 = flat_idx % d1;
    flat_idx /= d1;

    int32_t i0 = flat_idx;

    // Compute memory offset using strides
    return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3;
}

} // namespace

/**
 * CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool.
 *
 * @param k_fp8_bytes Quantized FP8 data [num_tokens, 128]
 * @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4]
 * @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be non-contiguous)
 * @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens]
 * @param slot_mapping_scale Flat element index for scale data start position [num_tokens]
 * @param num_tokens Number of tokens
 * @param head_dim Head dimension (must be 128)
 * @param scale_size Scale size in bytes (must be 4)
 * @param cache_stride_0 Stride for k_cache dimension 0 (in bytes)
 * @param cache_stride_1 Stride for k_cache dimension 1 (in bytes)
 * @param cache_stride_2 Stride for k_cache dimension 2 (in bytes)
 * @param cache_stride_3 Stride for k_cache dimension 3 (in bytes)
 * @param cache_dim_0 Size of k_cache dimension 0
 * @param cache_dim_1 Size of k_cache dimension 1
 * @param cache_dim_2 Size of k_cache dimension 2
 * @param cache_dim_3 Size of k_cache dimension 3
 */
__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes,
    uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache,
    int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens,
    int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2,
    int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3)
{
    // For head_dim=128, each thread handles 4 bytes/elements per read/write instruction
    constexpr int VEC_SIZE = 4;

    // Token index from blockIdx.x
    int32_t token_idx = blockIdx.x;

    if (token_idx >= num_tokens)
    {
        return;
    }

    int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx];
    int64_t flat_idx_scale_base = slot_mapping_scale[token_idx];

    if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0)
    {
        return;
    }

    int32_t head_dim_idx = threadIdx.x * VEC_SIZE;
    int64_t flat_idx = flat_idx_fp8_base + head_dim_idx;

    // Convert flat index to memory offset using strides (the k cache pool from the cpp kv cache manager may be
    // non-contiguous)
    int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3,
        cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
    int64_t src_offset = token_idx * head_dim + head_dim_idx;

    // 4-byte vectorized write
    *reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]);

    // Only thread 0 writes the single 4-byte scale value
    if (threadIdx.x == 0)
    {
        int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2,
            cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3);
        int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4

        // 4-byte write for the scale
        *reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale])
            = *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]);
    }
}

void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache,
    int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim,
    int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3,
    int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream)
{
    if (num_tokens == 0)
    {
        return;
    }

    // Assertions for the DeepSeek-V3 configuration
    constexpr int32_t QUANT_BLOCK_SIZE = 128;
    TLLM_CHECK_WITH_INFO(
        head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3 indexer cache (got %d)", head_dim);
    TLLM_CHECK_WITH_INFO(
        scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size);

    // For head_dim=128, we use 32 threads to handle the 128 bytes per token plus the extra 4 bytes for the scale
    constexpr int32_t THREADS_PER_BLOCK = 32;

    dim3 block(THREADS_PER_BLOCK);
    dim3 grid(num_tokens);

    indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8,
        slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2,
        cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3);

    // Check for kernel launch errors
    TLLM_CUDA_CHECK(cudaGetLastError());
}

} // namespace tensorrt_llm::kernels

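A quick way to sanity-check the unravel-and-restride arithmetic above is a small host-side Python mirror of flatIndexToMemoryOffset (illustrative only, not part of this commit; the shapes below are made up). For a contiguous cache the computed offset must equal the flat element index; the strided form matters because the k cache pool handed over by the C++ KV-cache manager may be non-contiguous.

    def flat_index_to_memory_offset(flat_idx, dims, strides):
        # Mirror of the device helper: unravel innermost-first, then apply strides.
        d0, d1, d2, d3 = dims
        s0, s1, s2, s3 = strides
        i3 = flat_idx % d3
        flat_idx //= d3
        i2 = flat_idx % d2
        flat_idx //= d2
        i1 = flat_idx % d1
        i0 = flat_idx // d1
        return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3

    # Contiguous [num_blocks, block_size, 1, per_token_size] example: row-major strides,
    # so the byte offset equals the flat element index.
    dims = (4, 8, 1, 132)
    strides = (8 * 1 * 132, 1 * 132, 132, 1)
    assert flat_index_to_memory_offset(500, dims, strides) == 500
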

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ add_library(
   fp8PerTensorScaleMoe.cpp
   fp4BlockScaleMoe.cpp
   noAuxTcOp.cpp
+  IndexerKCacheScatterOp.cpp
   ncclCommunicatorOp.cpp
   parallelDecodeKVCacheUpdateOp.cpp
   redrafterCurandOp.cpp
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@

/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"

#include "tensorrt_llm/kernels/IndexerKCacheScatter.h"

namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;

namespace torch_ext
{

void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
    th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
{
    // Validate that all tensors are CUDA tensors
    TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
            && slot_mapping_scale.is_cuda(),
        "All tensors must be CUDA tensors");

    // Validate tensor dimensions
    TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
    TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
    TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
    TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");

    // Enforce that k_cache is a 4D tensor
    TORCH_CHECK(k_cache.dim() == 4, "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got ",
        k_cache.dim(), " dimensions");

    // Validate tensor dtypes
    TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
    TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
    TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
    TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

    // Validate that tensor shapes are consistent
    auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
    TORCH_CHECK(
        k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
    TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
    TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");

    // Validate that tensors are contiguous (except k_cache, which may be non-contiguous)
    TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
    TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
    // k_cache can be non-contiguous - we handle this via strides
    TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
    TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

    int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1));     // head_dim = quant_block_size = 128
    int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes

    int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
    int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
    int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
    int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size

    // Validate the indexer k cache pool against the DeepSeek-V3.2 constraints
    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got ", cache_dim_2);
    TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got ", head_dim);
    TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got ", scale_size);

    int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
    int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
    int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
    int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

    auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());

    tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
        k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
        num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
        cache_stride_1, cache_stride_2, cache_stride_3, stream);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("indexer_k_cache_scatter_op", &torch_ext::indexer_k_cache_scatter_op);
}

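For reference, a minimal invocation sketch of the registered op from Python (not part of this commit). It assumes that importing tensorrt_llm loads the trtllm op library, and it assumes a per-token layout in which the 4 scale bytes sit right after the 128 FP8 bytes; the real layout is decided by the KV-cache manager through the slot mappings, not by the op.

    import torch
    import tensorrt_llm  # assumption: importing this registers torch.ops.trtllm.*

    num_tokens, head_dim, scale_size = 2, 128, 4
    per_token_size = head_dim + scale_size  # assumed 132-byte per-token region

    k_fp8_bytes = torch.randint(0, 256, (num_tokens, head_dim), dtype=torch.uint8, device="cuda")
    k_scale_bytes = torch.randint(0, 256, (num_tokens, scale_size), dtype=torch.uint8, device="cuda")
    k_cache = torch.zeros(1, 64, 1, per_token_size, dtype=torch.uint8, device="cuda")

    # Flat element indices into k_cache where each token's FP8 bytes / scale bytes start.
    slot_mapping_fp8 = torch.tensor([0, per_token_size], dtype=torch.int64, device="cuda")
    slot_mapping_scale = slot_mapping_fp8 + head_dim

    torch.ops.trtllm.indexer_k_cache_scatter_op(
        k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8, slot_mapping_scale)
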

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 4 additions & 16 deletions
@@ -872,24 +872,12 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
         k_scale_bytes = k_scale_flat.view(torch.uint8).view(
             num_tokens, scale_size)
 
-        # Scatter FP8 data
+        # Use CUDA kernel to scatter FP8 and scale bytes into cache
         flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
-        byte_offsets = torch.arange(head_dim, device=k_cache.device).unsqueeze(
-            0)  # [1, head_dim]
-        scatter_indices_fp8 = flat_indices_fp8.unsqueeze(
-            1) + byte_offsets  # [num_tokens, head_dim]
-        scatter_indices_fp8 = _unravel_indices(scatter_indices_fp8,
-                                               k_cache.shape)
-        k_cache[scatter_indices_fp8] = k_fp8_bytes
-
         flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
-        byte_offsets = torch.arange(
-            scale_size, device=k_cache.device).unsqueeze(0)  # [1, scale_size]
-        scatter_indices_scale = flat_indices_scale.unsqueeze(
-            1) + byte_offsets  # [num_tokens, scale_size]
-        scatter_indices_scale = _unravel_indices(scatter_indices_scale,
-                                                 k_cache.shape)
-        k_cache[scatter_indices_scale] = k_scale_bytes
+        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
+                                                    k_cache, flat_indices_fp8,
+                                                    flat_indices_scale)
 
     def _gather_k_cache_for_chunk(
         self,

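A note for testing: the pure-PyTorch path removed above still makes a handy reference against the kernel. A rough, self-contained stand-in for a contiguous cache (the helper name is hypothetical; it replaces the _unravel_indices step with a flat view and assumes non-negative slot indices) could be:

    import torch

    def reference_scatter(bytes_2d: torch.Tensor, flat_indices: torch.Tensor, k_cache: torch.Tensor) -> None:
        # Scatter each row of bytes_2d starting at its flat element index; valid only for a contiguous cache.
        width = bytes_2d.size(1)
        offsets = torch.arange(width, device=k_cache.device)  # [width]
        indices = flat_indices.unsqueeze(1) + offsets          # [num_tokens, width]
        k_cache.view(-1)[indices.reshape(-1)] = bytes_2d.reshape(-1)

    # e.g. run both paths on clones of the same cache and compare:
    # reference_scatter(k_fp8_bytes, flat_indices_fp8, k_cache_ref)
    # reference_scatter(k_scale_bytes, flat_indices_scale, k_cache_ref)
    # torch.testing.assert_close(k_cache_ref, k_cache_kernel)
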