Skip to content

Commit e940e58

Browse files
authored
[TRTLLM-10407][perf] Enable CuteDSL indexer_top_k in model (#12236)
Signed-off-by: Mindy Li <[email protected]>
1 parent 32be345 commit e940e58

8 files changed

Lines changed: 1977 additions & 287 deletions

File tree

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
MLAParams, PositionalEmbeddingParams)
1212
from tensorrt_llm._torch.attention_backend.trtllm import (
1313
TrtllmAttention, TrtllmAttentionMetadata)
14+
from tensorrt_llm._torch.cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
1415
from tensorrt_llm._torch.distributed.ops import allgather
1516
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
1617
from tensorrt_llm._torch.modules.linear import Linear
@@ -1025,8 +1026,18 @@ def __init__(self,
10251026
self.scale_fmt = "ue8m0"
10261027
self.aux_stream = aux_stream
10271028
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
1029+
self.use_cute_dsl_topk = (sparse_attention_config.use_cute_dsl_topk
1030+
and IS_CUTLASS_DSL_AVAILABLE)
10281031
self.weight_scale_factor = self.softmax_scale * self.n_heads**-0.5
10291032

1033+
if self.use_cute_dsl_topk and layer_idx == 0:
1034+
from tensorrt_llm._torch.custom_ops import cute_dsl_custom_ops
1035+
1036+
# The top-k input tensor is currently float32, so warm up with that dtype.
1037+
# NOTE: update the dtype passed to the warmup call if the top-k input tensor's dtype changes.
1038+
cute_dsl_custom_ops.warmup_cute_dsl_indexer_topk(
1039+
dtype=torch.float32, top_k=self.index_topk)
1040+
10301041
def post_load_weights(self):
10311042
"""Fuse wk + weights_proj into single FP32 weight for cuBLAS GEMM (TF32 on Ampere+)."""
10321043
# wk: [head_dim, hidden_size] + weights_proj: [n_heads, hidden_size]
@@ -1644,10 +1655,23 @@ def sparse_attn_indexer(
16441655
# This is because rowEnd = seq_len - next_n + offset + 1
16451656
gen_kv_lens_cuda = metadata.kv_lens_cuda_runtime[
16461657
num_contexts:num_contexts + num_generations]
1647-
torch.ops.trtllm.indexer_topk_decode(
1648-
logits_decode, gen_kv_lens_cuda,
1649-
topk_indices_buffer[num_ctx_tokens:num_ctx_tokens +
1650-
num_gen_tokens, :], next_n)
1658+
# CuTE DSL top-k allocates O(num_gen_tokens * kv_len) global
1659+
# memory. Beyond 256 tokens the extra memory becomes significant,
1660+
# so we cap it at 256 for now and fall back to the CUDA C++
1661+
# indexer_topk_decode. This limit can be removed if GPU memory
1662+
# is not a bottleneck.
1663+
if self.use_cute_dsl_topk and num_gen_tokens <= 256:
1664+
torch.ops.trtllm.cute_dsl_indexer_topk_decode(
1665+
logits_decode, gen_kv_lens_cuda,
1666+
topk_indices_buffer[num_ctx_tokens:num_ctx_tokens +
1667+
num_gen_tokens, :], self.index_topk,
1668+
next_n)
1669+
else:
1670+
torch.ops.trtllm.indexer_topk_decode(
1671+
logits_decode, gen_kv_lens_cuda,
1672+
topk_indices_buffer[num_ctx_tokens:num_ctx_tokens +
1673+
num_gen_tokens, :], next_n,
1674+
self.index_topk)
16511675
else:
16521676
# padded
16531677
positions = torch.arange(

0 commit comments

Comments
 (0)