|
| 1 | +/* |
| 2 | + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "IndexerKCacheScatter.h" |
| 18 | +#include "tensorrt_llm/common/assert.h" |
| 19 | +#include "tensorrt_llm/common/cudaUtils.h" |
| 20 | + |
| 21 | +namespace tensorrt_llm::kernels |
| 22 | +{ |
| 23 | + |
| 24 | +namespace |
| 25 | +{ |
| 26 | +/** |
| 27 | + * Given a flat element index and tensor shape [d0, d1, d2, d3] with strides [s0, s1, s2, s3], |
| 28 | + * find the actual memory offset within the given k cache pool using the strides. |
| 29 | + */ |
| 30 | +__device__ __forceinline__ int64_t flatIndexToMemoryOffset( |
| 31 | + int64_t flat_idx, int32_t d0, int32_t d1, int32_t d2, int32_t d3, int64_t s0, int64_t s1, int64_t s2, int64_t s3) |
| 32 | +{ |
| 33 | + // Unravel from innermost to outermost dimension |
| 34 | + int32_t i3 = flat_idx % d3; |
| 35 | + flat_idx /= d3; |
| 36 | + |
| 37 | + int32_t i2 = flat_idx % d2; |
| 38 | + flat_idx /= d2; |
| 39 | + |
| 40 | + int32_t i1 = flat_idx % d1; |
| 41 | + flat_idx /= d1; |
| 42 | + |
| 43 | + int32_t i0 = flat_idx; |
| 44 | + |
| 45 | + // Compute memory offset using strides |
| 46 | + return i0 * s0 + i1 * s1 + i2 * s2 + i3 * s3; |
| 47 | +} |
| 48 | + |
| 49 | +} // namespace |
| 50 | + |
| 51 | +/** |
| 52 | + * CUDA kernel to scatter both FP8 K values and scales into the indexer k cache pool |
| 53 | + * |
| 54 | + * @param k_fp8_bytes Quantized FP8 data [num_tokens, 128] |
| 55 | + * @param k_scale_bytes Quantized scales (1 per token) [num_tokens, 4] |
| 56 | + * @param k_cache Indexer k cache pool with shape [num_blocks, block_size, 1, per_token_size] (can be |
| 57 | + * non-contiguous) |
| 58 | + * @param slot_mapping_fp8 Flat element index for FP8 data start position [num_tokens] |
| 59 | + * @param slot_mapping_scale Flat element index for scale data start position [num_tokens] |
| 60 | + * @param num_tokens Number of tokens |
| 61 | + * @param head_dim Head dimension (must be 128) |
| 62 | + * @param scale_size Scale size in bytes (must be 4) |
| 63 | + * @param cache_stride_0 Stride for k_cache dimension 0 (in bytes) |
| 64 | + * @param cache_stride_1 Stride for k_cache dimension 1 (in bytes) |
| 65 | + * @param cache_stride_2 Stride for k_cache dimension 2 (in bytes) |
| 66 | + * @param cache_stride_3 Stride for k_cache dimension 3 (in bytes) |
| 67 | + * @param cache_dim_0 Size of k_cache dimension 0 |
| 68 | + * @param cache_dim_1 Size of k_cache dimension 1 |
| 69 | + * @param cache_dim_2 Size of k_cache dimension 2 |
| 70 | + * @param cache_dim_3 Size of k_cache dimension 3 |
| 71 | + */ |
| 72 | +__global__ void indexerKCacheScatterUnifiedKernel(uint8_t const* __restrict__ k_fp8_bytes, |
| 73 | + uint8_t const* __restrict__ k_scale_bytes, uint8_t* __restrict__ k_cache, |
| 74 | + int64_t const* __restrict__ slot_mapping_fp8, int64_t const* __restrict__ slot_mapping_scale, int32_t num_tokens, |
| 75 | + int32_t head_dim, int32_t scale_size, int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, |
| 76 | + int64_t cache_stride_3, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3) |
| 77 | +{ |
| 78 | + // For head_dim=128, each thread handles 4 bytes/elements per read/write instruction |
| 79 | + constexpr int VEC_SIZE = 4; |
| 80 | + |
| 81 | + // Token index from block.x |
| 82 | + int32_t token_idx = blockIdx.x; |
| 83 | + |
| 84 | + if (token_idx >= num_tokens) |
| 85 | + { |
| 86 | + return; |
| 87 | + } |
| 88 | + |
| 89 | + int64_t flat_idx_fp8_base = slot_mapping_fp8[token_idx]; |
| 90 | + int64_t flat_idx_scale_base = slot_mapping_scale[token_idx]; |
| 91 | + |
| 92 | + if (flat_idx_fp8_base < 0 || flat_idx_scale_base < 0) |
| 93 | + { |
| 94 | + return; |
| 95 | + } |
| 96 | + |
| 97 | + int32_t head_dim_idx = threadIdx.x * VEC_SIZE; |
| 98 | + int64_t flat_idx = flat_idx_fp8_base + head_dim_idx; |
| 99 | + |
| 100 | + // Convert flat index to memory offset using strides (k cache pool from cpp kv cache manager is non-contiguous) |
| 101 | + int64_t dst_offset = flatIndexToMemoryOffset(flat_idx, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, |
| 102 | + cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3); |
| 103 | + int64_t src_offset = token_idx * head_dim + head_dim_idx; |
| 104 | + |
| 105 | + // 4 bytes write |
| 106 | + *reinterpret_cast<uint32_t*>(&k_cache[dst_offset]) = *reinterpret_cast<uint32_t const*>(&k_fp8_bytes[src_offset]); |
| 107 | + |
| 108 | + // Only thread 0 writes the single 4 bytes scale value |
| 109 | + if (threadIdx.x == 0) |
| 110 | + { |
| 111 | + int64_t dst_offset_scale = flatIndexToMemoryOffset(flat_idx_scale_base, cache_dim_0, cache_dim_1, cache_dim_2, |
| 112 | + cache_dim_3, cache_stride_0, cache_stride_1, cache_stride_2, cache_stride_3); |
| 113 | + int64_t src_offset_scale = token_idx * scale_size; // scale_size = 4 |
| 114 | + |
| 115 | + // 4 bytes write for scale |
| 116 | + *reinterpret_cast<uint32_t*>(&k_cache[dst_offset_scale]) |
| 117 | + = *reinterpret_cast<uint32_t const*>(&k_scale_bytes[src_offset_scale]); |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +void invokeIndexerKCacheScatter(uint8_t const* k_fp8_bytes, uint8_t const* k_scale_bytes, uint8_t* k_cache, |
| 122 | + int64_t const* slot_mapping_fp8, int64_t const* slot_mapping_scale, int32_t num_tokens, int32_t head_dim, |
| 123 | + int32_t scale_size, int32_t cache_dim_0, int32_t cache_dim_1, int32_t cache_dim_2, int32_t cache_dim_3, |
| 124 | + int64_t cache_stride_0, int64_t cache_stride_1, int64_t cache_stride_2, int64_t cache_stride_3, cudaStream_t stream) |
| 125 | +{ |
| 126 | + if (num_tokens == 0) |
| 127 | + { |
| 128 | + return; |
| 129 | + } |
| 130 | + |
| 131 | + // Assertions for DeepSeek-V3 configuration |
| 132 | + constexpr int32_t QUANT_BLOCK_SIZE = 128; |
| 133 | + TLLM_CHECK_WITH_INFO( |
| 134 | + head_dim == QUANT_BLOCK_SIZE, "head_dim must equal 128 for DeepSeek-V3 indexer cache (got %d)", head_dim); |
| 135 | + TLLM_CHECK_WITH_INFO( |
| 136 | + scale_size == 4, "scale_size must equal 4 bytes (1 float32 scale per token, got %d)", scale_size); |
| 137 | + |
| 138 | + // For head_dim=128, we use 32 threads to handle 128 bytes per token and extra 4 bytes for scale |
| 139 | + constexpr int32_t THREADS_PER_BLOCK = 32; |
| 140 | + |
| 141 | + dim3 block(THREADS_PER_BLOCK); |
| 142 | + dim3 grid(num_tokens); |
| 143 | + |
| 144 | + indexerKCacheScatterUnifiedKernel<<<grid, block, 0, stream>>>(k_fp8_bytes, k_scale_bytes, k_cache, slot_mapping_fp8, |
| 145 | + slot_mapping_scale, num_tokens, head_dim, scale_size, cache_stride_0, cache_stride_1, cache_stride_2, |
| 146 | + cache_stride_3, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3); |
| 147 | + |
| 148 | + // Check for kernel launch errors |
| 149 | + TLLM_CUDA_CHECK(cudaGetLastError()); |
| 150 | +} |
| 151 | + |
| 152 | +} // namespace tensorrt_llm::kernels |
0 commit comments