Skip to content

Commit 44cdcab

Browse files
committed
[https://nvbugs/5983390][perf] Reduce host overhead in DSA MLA attention path
Pass pre-computed num_contexts/num_ctx_tokens to thop::attention and trtllm_gen_attention to eliminate per-layer sum().item() calls that recompute batch structure from host_request_types/host_context_lengths.

Move view/slice/reinterpret ops from Python _update_k_cache into the C++ indexer_k_cache_scatter_op kernel: accept original k_fp8 (FP8) and k_scale (float32) tensors directly with num_tokens, avoiding per-layer torch.empty, view, as_strided and slice overhead on the host.

Signed-off-by: Jin Li <[email protected]>
1 parent 5fd0517 commit 44cdcab

8 files changed

Lines changed: 93 additions & 98 deletions

File tree

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ void initBindings(nb::module_& m)
7070
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
7171
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
7272
nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt,
73-
nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation",
73+
nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("opt_num_contexts") = std::nullopt,
74+
nb::arg("opt_num_ctx_tokens") = std::nullopt, "Multi-head attention operation",
7475
nb::call_guard<nb::gil_scoped_release>());
7576

7677
m.def(

cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp

Lines changed: 43 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -28,69 +28,56 @@ TRTLLM_NAMESPACE_BEGIN
2828
namespace torch_ext
2929
{
3030

31-
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
32-
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
31+
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache,
32+
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens)
3333
{
34-
// Validate all tensors are CUDA tensors
35-
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
34+
// k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
35+
// k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
36+
// slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
37+
// k_cache: [num_blocks, block_size, 1, per_token_size] uint8
38+
39+
TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
3640
&& slot_mapping_scale.is_cuda(),
3741
"All tensors must be CUDA tensors");
3842

39-
// Validate tensor dimensions
40-
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
41-
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
42-
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
43-
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");
44-
45-
// Enforce k_cache is 4D tensor
46-
TORCH_CHECK(k_cache.dim() == 4,
47-
"k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
43+
TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]");
44+
TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]");
45+
TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims",
4846
static_cast<int>(k_cache.dim()));
4947

50-
// Validate tensor dtypes
51-
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
52-
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
53-
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
54-
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");
55-
56-
// Validate tensor shapes are consistent
57-
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
58-
TORCH_CHECK(
59-
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
60-
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
61-
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");
62-
63-
// Validate tensors are contiguous (except k_cache which may be non-contiguous)
64-
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
65-
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
66-
// k_cache can be non-contiguous - we handle this via strides
48+
TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous");
49+
TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous");
6750
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
6851
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");
6952

70-
int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
71-
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes
72-
73-
int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
74-
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
75-
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
76-
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size
77-
78-
// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
79-
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
80-
TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
81-
TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);
82-
83-
int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
84-
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
85-
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
86-
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
87-
88-
auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());
89-
90-
tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
91-
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
92-
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
93-
cache_stride_1, cache_stride_2, cache_stride_3, stream);
53+
// FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
54+
int32_t const head_dim = static_cast<int32_t>(k_fp8.size(1));
55+
// float32 scale: each element is 4 bytes.
56+
int32_t const scale_size = static_cast<int32_t>(k_scale.size(1)) * 4;
57+
58+
int32_t const cache_dim_0 = static_cast<int32_t>(k_cache.size(0));
59+
int32_t const cache_dim_1 = static_cast<int32_t>(k_cache.size(1));
60+
int32_t const cache_dim_2 = static_cast<int32_t>(k_cache.size(2));
61+
int32_t const cache_dim_3 = static_cast<int32_t>(k_cache.size(3));
62+
63+
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2);
64+
TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim);
65+
TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size);
66+
67+
int64_t const cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
68+
int64_t const cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
69+
int64_t const cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
70+
int64_t const cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
71+
72+
auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device());
73+
74+
// Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
75+
// For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
76+
tk::invokeIndexerKCacheScatter(reinterpret_cast<uint8_t const*>(k_fp8.data_ptr()),
77+
reinterpret_cast<uint8_t const*>(k_scale.data_ptr()), k_cache.data_ptr<uint8_t>(),
78+
slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(), static_cast<int32_t>(num_tokens),
79+
head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
80+
cache_stride_2, cache_stride_3, stream);
9481
}
9582

9683
} // namespace torch_ext
@@ -100,8 +87,8 @@ TRTLLM_NAMESPACE_END
10087
TORCH_LIBRARY_FRAGMENT(trtllm, m)
10188
{
10289
m.def(
103-
"indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
104-
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
90+
"indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, "
91+
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()");
10592
}
10693

10794
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
630630
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
631631
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
632632
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
633-
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits)
633+
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
634+
std::optional<int64_t> opt_num_contexts, std::optional<int64_t> opt_num_ctx_tokens)
634635
{
635636
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
636637
// Use these tensors to infer if the attention is using KV cache
@@ -833,19 +834,28 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
833834
}
834835
bool const is_gen_only = attn_input_type == AttentionInputType::GenerationOnly;
835836

836-
int32_t num_contexts = 0;
837-
// count context requests
838-
for (int32_t idx = 0; idx < num_seqs; idx++)
837+
int32_t num_contexts;
838+
if (opt_num_contexts.has_value())
839839
{
840-
if (request_types[idx] != RequestType::kCONTEXT)
840+
num_contexts = static_cast<int32_t>(opt_num_contexts.value());
841+
}
842+
else
843+
{
844+
num_contexts = 0;
845+
for (int32_t idx = 0; idx < num_seqs; idx++)
841846
{
842-
break;
847+
if (request_types[idx] != RequestType::kCONTEXT)
848+
{
849+
break;
850+
}
851+
++num_contexts;
843852
}
844-
++num_contexts;
845853
}
846854
int32_t const num_generations = num_seqs - num_contexts;
847855
int32_t const num_tokens = qkv_or_q.size(0);
848-
int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
856+
int32_t const num_ctx_tokens = opt_num_ctx_tokens.has_value()
857+
? static_cast<int32_t>(opt_num_ctx_tokens.value())
858+
: host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
849859
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
850860
auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int32_t>();
851861
auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int32_t>();

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
7878
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
7979
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
8080
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata = std::nullopt,
81-
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt);
81+
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt,
82+
std::optional<int64_t> opt_num_contexts = std::nullopt, std::optional<int64_t> opt_num_ctx_tokens = std::nullopt);
8283

8384
struct KvCachePoolPointers
8485
{

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,34 +1356,18 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
13561356
if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None:
13571357
return
13581358

1359-
# [num_blocks, block_size, 1, per_token_size ]
13601359
k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
13611360
self.layer_idx)
13621361

13631362
num_tokens = k_fp8.shape[0]
1364-
head_dim = k_fp8.shape[1]
1365-
scale_size = k_scale.shape[1] * 4 # Convert to bytes (float32 = 4 bytes)
1366-
1367-
# Convert to bytes: flatten first, then view as uint8, then reshape
1368-
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(
1369-
num_tokens, head_dim)
1370-
1371-
# k_scale: for single-element tensors, contiguous() may be no-op
1372-
# Fix stride(-1) for byte-level view
1373-
k_scale_flat = k_scale.view(-1)
1374-
if k_scale_flat.stride(-1) != 1:
1375-
k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
1376-
size=(k_scale_flat.numel(), ),
1377-
stride=(1, ))
1378-
k_scale_bytes = k_scale_flat.view(torch.uint8).view(
1379-
num_tokens, scale_size)
1380-
1381-
# Use CUDA kernel to scatter FP8 and scale bytes into cache
1382-
flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
1383-
flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
1384-
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
1385-
k_cache, flat_indices_fp8,
1386-
flat_indices_scale)
1363+
1364+
# The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw
1365+
# bytes internally and only reads the first num_tokens entries from
1366+
# the slot mapping buffers, avoiding Python-side view/slice overhead.
1367+
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache,
1368+
metadata.slot_mapping_fp8,
1369+
metadata.slot_mapping_scale,
1370+
num_tokens)
13871371

13881372
def sparse_attn_indexer(
13891373
self,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,9 @@ def run(
407407
mla_bmm1_scale: Optional[torch.Tensor] = None,
408408
mla_bmm2_scale: Optional[torch.Tensor] = None,
409409
quant_q_buffer: Optional[torch.Tensor] = None,
410+
num_contexts: int = 0,
411+
num_generations: int = 0,
412+
num_ctx_tokens: int = 0,
410413
):
411414
"""
412415
Run the attention operation.
@@ -639,6 +642,9 @@ def run(
639642
self.quant_config,
640643
self.kv_cache_manager,
641644
global_layer_idx=self.global_layer_idx,
645+
num_contexts=num_contexts,
646+
num_generations=num_generations,
647+
num_ctx_tokens=num_ctx_tokens,
642648
)
643649
else:
644650
thop.attention(
@@ -722,6 +728,8 @@ def run(
722728
quant_q_buffer,
723729
self.flash_mla_tile_scheduler_metadata,
724730
self.flash_mla_num_splits,
731+
num_contexts,
732+
num_ctx_tokens,
725733
)
726734

727735
if self.print_skip_softmax_stat:
@@ -2049,7 +2057,10 @@ def forward(
20492057
fmha_scheduler_counter=fmha_scheduler_counter,
20502058
mla_bmm1_scale=mla_bmm1_scale,
20512059
mla_bmm2_scale=mla_bmm2_scale,
2052-
quant_q_buffer=quant_q_buffer)
2060+
quant_q_buffer=quant_q_buffer,
2061+
num_contexts=metadata.num_contexts,
2062+
num_generations=metadata.num_generations,
2063+
num_ctx_tokens=metadata.num_ctx_tokens)
20532064

20542065
if output_sf is None:
20552066
return output

tensorrt_llm/_torch/attention_backend/trtllm_gen.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,6 +1562,9 @@ def trtllm_gen_attention(
15621562
quant_config: Optional[QuantConfig],
15631563
kv_cache_manager: Optional[KVCacheManager],
15641564
global_layer_idx: Optional[int] = None,
1565+
num_contexts: int = 0,
1566+
num_generations: int = 0,
1567+
num_ctx_tokens: int = 0,
15651568
) -> None:
15661569
"""
15671570
TrtLLM-Gen attention using flashinfer backend.
@@ -1691,9 +1694,6 @@ def trtllm_gen_attention(
16911694
if attention_input_type is not None:
16921695
attn_input_type = AttentionInputType(attention_input_type)
16931696

1694-
num_contexts, num_generations = _parse_request_types(host_request_types)
1695-
1696-
num_ctx_tokens = int(host_context_lengths[:num_contexts].sum()) if num_contexts > 0 else 0
16971697
num_gen_tokens = num_tokens - num_ctx_tokens
16981698

16991699
# Prepare Workspace

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def test_indexer_k_cache_scatter_custom_op():
703703
dtype=torch.bfloat16)
704704
k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original)
705705

706-
# Prepare byte-level data
706+
# Prepare byte-level data for the Python reference path
707707
scale_size = k_scale.shape[1] * 4
708708
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim)
709709
k_scale_flat = k_scale.view(-1)
@@ -742,9 +742,10 @@ def test_indexer_k_cache_scatter_custom_op():
742742

743743
# ========== Path 1: CUDA Kernel ==========
744744
print(f"\n=== Path 1: CUDA Kernel ===")
745-
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
746-
k_cache_cuda, flat_indices_fp8,
747-
flat_indices_scale)
745+
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache_cuda,
746+
metadata.slot_mapping_fp8,
747+
metadata.slot_mapping_scale,
748+
num_tokens)
748749
torch.cuda.synchronize()
749750
print(f"✓ CUDA kernel completed")
750751

0 commit comments

Comments
 (0)