@@ -28,69 +28,56 @@ TRTLLM_NAMESPACE_BEGIN
2828namespace torch_ext
2929{
3030
31- void indexer_k_cache_scatter_op (th::Tensor const & k_fp8_bytes , th::Tensor const & k_scale_bytes , th::Tensor& k_cache,
32- th::Tensor const & slot_mapping_fp8, th::Tensor const & slot_mapping_scale)
31+ void indexer_k_cache_scatter_op (th::Tensor const & k_fp8 , th::Tensor const & k_scale , th::Tensor& k_cache,
32+ th::Tensor const & slot_mapping_fp8, th::Tensor const & slot_mapping_scale, int64_t num_tokens )
3333{
34- // Validate all tensors are CUDA tensors
35- TORCH_CHECK (k_fp8_bytes.is_cuda () && k_scale_bytes.is_cuda () && k_cache.is_cuda () && slot_mapping_fp8.is_cuda ()
34+ // k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
35+ // k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
36+ // slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
37+ // k_cache: [num_blocks, block_size, 1, per_token_size] uint8
38+
39+ TORCH_CHECK (k_fp8.is_cuda () && k_scale.is_cuda () && k_cache.is_cuda () && slot_mapping_fp8.is_cuda ()
3640 && slot_mapping_scale.is_cuda (),
3741 " All tensors must be CUDA tensors" );
3842
39- // Validate tensor dimensions
40- TORCH_CHECK (k_fp8_bytes.dim () == 2 , " k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]" );
41- TORCH_CHECK (k_scale_bytes.dim () == 2 , " k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]" );
42- TORCH_CHECK (slot_mapping_fp8.dim () == 1 , " slot_mapping_fp8 must be a 1D Tensor [num_tokens]" );
43- TORCH_CHECK (slot_mapping_scale.dim () == 1 , " slot_mapping_scale must be a 1D Tensor [num_tokens]" );
44-
45- // Enforce k_cache is 4D tensor
46- TORCH_CHECK (k_cache.dim () == 4 ,
47- " k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions" ,
43+ TORCH_CHECK (k_fp8.dim () == 2 , " k_fp8 must be 2D [num_tokens, head_dim]" );
44+ TORCH_CHECK (k_scale.dim () == 2 , " k_scale must be 2D [num_tokens, scale_elements]" );
45+ TORCH_CHECK (k_cache.dim () == 4 , " k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims" ,
4846 static_cast <int >(k_cache.dim ()));
4947
50- // Validate tensor dtypes
51- TORCH_CHECK (k_fp8_bytes.scalar_type () == torch::kUInt8 , " k_fp8_bytes must be uint8" );
52- TORCH_CHECK (k_scale_bytes.scalar_type () == torch::kUInt8 , " k_scale_bytes must be uint8" );
53- TORCH_CHECK (slot_mapping_fp8.scalar_type () == torch::kInt64 , " slot_mapping_fp8 must be int64" );
54- TORCH_CHECK (slot_mapping_scale.scalar_type () == torch::kInt64 , " slot_mapping_scale must be int64" );
55-
56- // Validate tensor shapes are consistent
57- auto num_tokens = static_cast <int32_t >(k_fp8_bytes.size (0 ));
58- TORCH_CHECK (
59- k_scale_bytes.size (0 ) == num_tokens, " k_scale_bytes first dimension must equal k_fp8_bytes first dimension" );
60- TORCH_CHECK (slot_mapping_fp8.size (0 ) == num_tokens, " slot_mapping_fp8 length must equal num_tokens" );
61- TORCH_CHECK (slot_mapping_scale.size (0 ) == num_tokens, " slot_mapping_scale length must equal num_tokens" );
62-
63- // Validate tensors are contiguous (except k_cache which may be non-contiguous)
64- TORCH_CHECK (k_fp8_bytes.is_contiguous (), " k_fp8_bytes must be contiguous" );
65- TORCH_CHECK (k_scale_bytes.is_contiguous (), " k_scale_bytes must be contiguous" );
66- // k_cache can be non-contiguous - we handle this via strides
48+ TORCH_CHECK (k_fp8.is_contiguous (), " k_fp8 must be contiguous" );
49+ TORCH_CHECK (k_scale.is_contiguous (), " k_scale must be contiguous" );
6750 TORCH_CHECK (slot_mapping_fp8.is_contiguous (), " slot_mapping_fp8 must be contiguous" );
6851 TORCH_CHECK (slot_mapping_scale.is_contiguous (), " slot_mapping_scale must be contiguous" );
6952
70- int32_t head_dim = static_cast <int32_t >(k_fp8_bytes.size (1 )); // head_dim = quant_block_size = 128
71- int32_t scale_size = static_cast <int32_t >(k_scale_bytes.size (1 )); // scale_size = 4 bytes
72-
73- int32_t cache_dim_0 = static_cast <int32_t >(k_cache.size (0 )); // num_blocks
74- int32_t cache_dim_1 = static_cast <int32_t >(k_cache.size (1 )); // block_size
75- int32_t cache_dim_2 = static_cast <int32_t >(k_cache.size (2 )); // num_kv_heads
76- int32_t cache_dim_3 = static_cast <int32_t >(k_cache.size (3 )); // per_token_size
77-
78- // Validation for indexer k cache pool for DeepSeek-V3.2 constraints
79- TORCH_CHECK (cache_dim_2 == 1 , " k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d" , cache_dim_2);
80- TORCH_CHECK (head_dim == 128 , " k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d" , head_dim);
81- TORCH_CHECK (scale_size == 4 , " k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d" , scale_size);
82-
83- int64_t cache_stride_0 = static_cast <int64_t >(k_cache.stride (0 ));
84- int64_t cache_stride_1 = static_cast <int64_t >(k_cache.stride (1 ));
85- int64_t cache_stride_2 = static_cast <int64_t >(k_cache.stride (2 ));
86- int64_t cache_stride_3 = static_cast <int64_t >(k_cache.stride (3 ));
87-
88- auto stream = at::cuda::getCurrentCUDAStream (k_fp8_bytes.get_device ());
89-
90- tk::invokeIndexerKCacheScatter (k_fp8_bytes.data_ptr <uint8_t >(), k_scale_bytes.data_ptr <uint8_t >(),
91- k_cache.data_ptr <uint8_t >(), slot_mapping_fp8.data_ptr <int64_t >(), slot_mapping_scale.data_ptr <int64_t >(),
92- num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
93- cache_stride_1, cache_stride_2, cache_stride_3, stream);
53+ // FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
54+ int32_t const head_dim = static_cast <int32_t >(k_fp8.size (1 ));
55+ // float32 scale: each element is 4 bytes.
56+ int32_t const scale_size = static_cast <int32_t >(k_scale.size (1 )) * 4 ;
57+
58+ int32_t const cache_dim_0 = static_cast <int32_t >(k_cache.size (0 ));
59+ int32_t const cache_dim_1 = static_cast <int32_t >(k_cache.size (1 ));
60+ int32_t const cache_dim_2 = static_cast <int32_t >(k_cache.size (2 ));
61+ int32_t const cache_dim_3 = static_cast <int32_t >(k_cache.size (3 ));
62+
63+ TORCH_CHECK (cache_dim_2 == 1 , " k_cache dimension 2 must be 1, got %d" , cache_dim_2);
64+ TORCH_CHECK (head_dim == 128 , " k_fp8 head_dim must be 128, got %d" , head_dim);
65+ TORCH_CHECK (scale_size == 4 , " k_scale scale_size must be 4 bytes, got %d" , scale_size);
66+
67+ int64_t const cache_stride_0 = static_cast <int64_t >(k_cache.stride (0 ));
68+ int64_t const cache_stride_1 = static_cast <int64_t >(k_cache.stride (1 ));
69+ int64_t const cache_stride_2 = static_cast <int64_t >(k_cache.stride (2 ));
70+ int64_t const cache_stride_3 = static_cast <int64_t >(k_cache.stride (3 ));
71+
72+ auto stream = at::cuda::getCurrentCUDAStream (k_fp8.get_device ());
73+
74+ // Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
75+ // For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
76+ tk::invokeIndexerKCacheScatter (reinterpret_cast <uint8_t const *>(k_fp8.data_ptr ()),
77+ reinterpret_cast <uint8_t const *>(k_scale.data_ptr ()), k_cache.data_ptr <uint8_t >(),
78+ slot_mapping_fp8.data_ptr <int64_t >(), slot_mapping_scale.data_ptr <int64_t >(), static_cast <int32_t >(num_tokens),
79+ head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
80+ cache_stride_2, cache_stride_3, stream);
9481}
9582
9683} // namespace torch_ext
@@ -100,8 +87,8 @@ TRTLLM_NAMESPACE_END
10087TORCH_LIBRARY_FRAGMENT (trtllm, m)
10188{
10289 m.def (
103- " indexer_k_cache_scatter_op(Tensor k_fp8_bytes , Tensor k_scale_bytes , Tensor(a!) k_cache, "
104- " Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()" );
90+ " indexer_k_cache_scatter_op(Tensor k_fp8 , Tensor k_scale , Tensor(a!) k_cache, "
91+ " Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens ) -> ()" );
10592}
10693
10794TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
0 commit comments