@@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN
2828namespace torch_ext
2929{
3030
31- void indexer_k_cache_scatter_op (th::Tensor const & k_fp8_bytes , th::Tensor const & k_scale_bytes , th::Tensor& k_cache,
32- th::Tensor const & slot_mapping_fp8, th::Tensor const & slot_mapping_scale)
31+ void indexer_k_cache_scatter_op (th::Tensor const & k_fp8 , th::Tensor const & k_scale , th::Tensor& k_cache,
32+ th::Tensor const & slot_mapping_fp8, th::Tensor const & slot_mapping_scale, int64_t num_tokens )
3333{
34- // Validate all tensors are CUDA tensors
35- TORCH_CHECK (k_fp8_bytes.is_cuda () && k_scale_bytes.is_cuda () && k_cache.is_cuda () && slot_mapping_fp8.is_cuda ()
34+ // k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
35+ // k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
36+ // slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
37+ // k_cache: [num_blocks, block_size, 1, per_token_size] uint8
38+
39+ TORCH_CHECK (k_fp8.is_cuda () && k_scale.is_cuda () && k_cache.is_cuda () && slot_mapping_fp8.is_cuda ()
3640 && slot_mapping_scale.is_cuda (),
3741 " All tensors must be CUDA tensors" );
3842
3943 // Validate tensor dimensions
40- TORCH_CHECK (k_fp8_bytes.dim () == 2 , " k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]" );
41- TORCH_CHECK (k_scale_bytes.dim () == 2 , " k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]" );
42- TORCH_CHECK (slot_mapping_fp8.dim () == 1 , " slot_mapping_fp8 must be a 1D Tensor [num_tokens]" );
43- TORCH_CHECK (slot_mapping_scale.dim () == 1 , " slot_mapping_scale must be a 1D Tensor [num_tokens]" );
44-
45- // Enforce k_cache is 4D tensor
46- TORCH_CHECK (k_cache.dim () == 4 ,
47- " k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions" ,
44+ TORCH_CHECK (k_fp8.dim () == 2 , " k_fp8 must be 2D [num_tokens, head_dim]" );
45+ TORCH_CHECK (k_scale.dim () == 2 , " k_scale must be 2D [num_tokens, scale_elements]" );
46+ TORCH_CHECK (slot_mapping_fp8.dim () == 1 , " slot_mapping_fp8 must be 1D [num_tokens]" );
47+ TORCH_CHECK (slot_mapping_scale.dim () == 1 , " slot_mapping_scale must be 1D [num_tokens]" );
48+ TORCH_CHECK (k_cache.dim () == 4 , " k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims" ,
4849 static_cast <int >(k_cache.dim ()));
4950
50- // Validate tensor dtypes
51- TORCH_CHECK (k_fp8_bytes.scalar_type () == torch::kUInt8 , " k_fp8_bytes must be uint8" );
52- TORCH_CHECK (k_scale_bytes.scalar_type () == torch::kUInt8 , " k_scale_bytes must be uint8" );
51+ // Validate tensor dtypes — reinterpret_cast below assumes specific element sizes
52+ TORCH_CHECK (k_fp8.element_size () == 1 , " k_fp8 must have 1-byte elements (e.g. FP8), got %d" , k_fp8.element_size ());
53+ TORCH_CHECK (k_scale.element_size () == 4 , " k_scale must have 4-byte elements (e.g. float32), got %d" ,
54+ k_scale.element_size ());
5355 TORCH_CHECK (slot_mapping_fp8.scalar_type () == torch::kInt64 , " slot_mapping_fp8 must be int64" );
5456 TORCH_CHECK (slot_mapping_scale.scalar_type () == torch::kInt64 , " slot_mapping_scale must be int64" );
5557
56- // Validate tensor shapes are consistent
57- auto num_tokens = static_cast <int32_t >(k_fp8_bytes.size (0 ));
58- TORCH_CHECK (
59- k_scale_bytes.size (0 ) == num_tokens, " k_scale_bytes first dimension must equal k_fp8_bytes first dimension" );
60- TORCH_CHECK (slot_mapping_fp8.size (0 ) == num_tokens, " slot_mapping_fp8 length must equal num_tokens" );
61- TORCH_CHECK (slot_mapping_scale.size (0 ) == num_tokens, " slot_mapping_scale length must equal num_tokens" );
62-
63- // Validate tensors are contiguous (except k_cache which may be non-contiguous)
64- TORCH_CHECK (k_fp8_bytes.is_contiguous (), " k_fp8_bytes must be contiguous" );
65- TORCH_CHECK (k_scale_bytes.is_contiguous (), " k_scale_bytes must be contiguous" );
66- // k_cache can be non-contiguous - we handle this via strides
58+ TORCH_CHECK (k_fp8.is_contiguous (), " k_fp8 must be contiguous" );
59+ TORCH_CHECK (k_scale.is_contiguous (), " k_scale must be contiguous" );
6760 TORCH_CHECK (slot_mapping_fp8.is_contiguous (), " slot_mapping_fp8 must be contiguous" );
6861 TORCH_CHECK (slot_mapping_scale.is_contiguous (), " slot_mapping_scale must be contiguous" );
6962
70- int32_t head_dim = static_cast <int32_t >(k_fp8_bytes.size (1 )); // head_dim = quant_block_size = 128
71- int32_t scale_size = static_cast <int32_t >(k_scale_bytes.size (1 )); // scale_size = 4 bytes
72-
73- int32_t cache_dim_0 = static_cast <int32_t >(k_cache.size (0 )); // num_blocks
74- int32_t cache_dim_1 = static_cast <int32_t >(k_cache.size (1 )); // block_size
75- int32_t cache_dim_2 = static_cast <int32_t >(k_cache.size (2 )); // num_kv_heads
76- int32_t cache_dim_3 = static_cast <int32_t >(k_cache.size (3 )); // per_token_size
77-
78- // Validation for indexer k cache pool for DeepSeek-V3.2 constraints
79- TORCH_CHECK (cache_dim_2 == 1 , " k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d" , cache_dim_2);
80- TORCH_CHECK (head_dim == 128 , " k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d" , head_dim);
81- TORCH_CHECK (scale_size == 4 , " k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d" , scale_size);
82-
83- int64_t cache_stride_0 = static_cast <int64_t >(k_cache.stride (0 ));
84- int64_t cache_stride_1 = static_cast <int64_t >(k_cache.stride (1 ));
85- int64_t cache_stride_2 = static_cast <int64_t >(k_cache.stride (2 ));
86- int64_t cache_stride_3 = static_cast <int64_t >(k_cache.stride (3 ));
87-
88- auto stream = at::cuda::getCurrentCUDAStream (k_fp8_bytes.get_device ());
89-
90- tk::invokeIndexerKCacheScatter (k_fp8_bytes.data_ptr <uint8_t >(), k_scale_bytes.data_ptr <uint8_t >(),
91- k_cache.data_ptr <uint8_t >(), slot_mapping_fp8.data_ptr <int64_t >(), slot_mapping_scale.data_ptr <int64_t >(),
92- num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
93- cache_stride_1, cache_stride_2, cache_stride_3, stream);
63+ // FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
64+ int32_t const head_dim = static_cast <int32_t >(k_fp8.size (1 ));
65+ // Scale size in bytes: num_scale_elements * bytes_per_element.
66+ int32_t const scale_size = static_cast <int32_t >(k_scale.size (1 )) * static_cast <int32_t >(k_scale.element_size ());
67+
68+ int32_t const cache_dim_0 = static_cast <int32_t >(k_cache.size (0 ));
69+ int32_t const cache_dim_1 = static_cast <int32_t >(k_cache.size (1 ));
70+ int32_t const cache_dim_2 = static_cast <int32_t >(k_cache.size (2 ));
71+ int32_t const cache_dim_3 = static_cast <int32_t >(k_cache.size (3 ));
72+
73+ TORCH_CHECK (cache_dim_2 == 1 , " k_cache dimension 2 must be 1, got %d" , cache_dim_2);
74+ TORCH_CHECK (head_dim == 128 , " k_fp8 head_dim must be 128, got %d" , head_dim);
75+ TORCH_CHECK (scale_size == 4 , " k_scale scale_size must be 4 bytes, got %d" , scale_size);
76+
77+ int64_t const cache_stride_0 = static_cast <int64_t >(k_cache.stride (0 ));
78+ int64_t const cache_stride_1 = static_cast <int64_t >(k_cache.stride (1 ));
79+ int64_t const cache_stride_2 = static_cast <int64_t >(k_cache.stride (2 ));
80+ int64_t const cache_stride_3 = static_cast <int64_t >(k_cache.stride (3 ));
81+
82+ auto stream = at::cuda::getCurrentCUDAStream (k_fp8.get_device ());
83+
84+ // Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
85+ // For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
86+ tk::invokeIndexerKCacheScatter (reinterpret_cast <uint8_t const *>(k_fp8.data_ptr ()),
87+ reinterpret_cast <uint8_t const *>(k_scale.data_ptr ()), k_cache.data_ptr <uint8_t >(),
88+ slot_mapping_fp8.data_ptr <int64_t >(), slot_mapping_scale.data_ptr <int64_t >(), static_cast <int32_t >(num_tokens),
89+ head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
90+ cache_stride_2, cache_stride_3, stream);
9491}
9592
9693} // namespace torch_ext
@@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END
10097TORCH_LIBRARY_FRAGMENT (trtllm, m)
10198{
10299 m.def (
103- " indexer_k_cache_scatter_op(Tensor k_fp8_bytes , Tensor k_scale_bytes , Tensor(a!) k_cache, "
104- " Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()" );
100+ " indexer_k_cache_scatter_op(Tensor k_fp8 , Tensor k_scale , Tensor(a!) k_cache, "
101+ " Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens ) -> ()" );
105102}
106103
107104TORCH_LIBRARY_IMPL (trtllm, CUDA, m)
0 commit comments