|
11 | 11 | MLAParams, PositionalEmbeddingParams) |
12 | 12 | from tensorrt_llm._torch.attention_backend.trtllm import ( |
13 | 13 | TrtllmAttention, TrtllmAttentionMetadata) |
| 14 | +from tensorrt_llm._torch.cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE |
14 | 15 | from tensorrt_llm._torch.distributed.ops import allgather |
15 | 16 | from tensorrt_llm._torch.modules.layer_norm import LayerNorm |
16 | 17 | from tensorrt_llm._torch.modules.linear import Linear |
@@ -1025,8 +1026,18 @@ def __init__(self, |
1025 | 1026 | self.scale_fmt = "ue8m0" |
1026 | 1027 | self.aux_stream = aux_stream |
1027 | 1028 | self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] |
| 1029 | + self.use_cute_dsl_topk = (sparse_attention_config.use_cute_dsl_topk |
| 1030 | + and IS_CUTLASS_DSL_AVAILABLE) |
1028 | 1031 | self.weight_scale_factor = self.softmax_scale * self.n_heads**-0.5 |
1029 | 1032 |
|
| 1033 | + if self.use_cute_dsl_topk and layer_idx == 0: |
| 1034 | + from tensorrt_llm._torch.custom_ops import cute_dsl_custom_ops |
| 1035 | + |
| 1036 | + # The top-k input tensor is currently float32, so warm up with that dtype. |
| 1037 | + # NOTE: update this if the dtype of the top-k input tensor changes. |
| 1038 | + cute_dsl_custom_ops.warmup_cute_dsl_indexer_topk( |
| 1039 | + dtype=torch.float32, top_k=self.index_topk) |
| 1040 | + |
1030 | 1041 | def post_load_weights(self): |
1031 | 1042 | """Fuse wk + weights_proj into single FP32 weight for cuBLAS GEMM (TF32 on Ampere+).""" |
1032 | 1043 | # wk: [head_dim, hidden_size] + weights_proj: [n_heads, hidden_size] |
@@ -1644,10 +1655,23 @@ def sparse_attn_indexer( |
1644 | 1655 | # This is because rowEnd = seq_len - next_n + offset + 1 |
1645 | 1656 | gen_kv_lens_cuda = metadata.kv_lens_cuda_runtime[ |
1646 | 1657 | num_contexts:num_contexts + num_generations] |
1647 | | - torch.ops.trtllm.indexer_topk_decode( |
1648 | | - logits_decode, gen_kv_lens_cuda, |
1649 | | - topk_indices_buffer[num_ctx_tokens:num_ctx_tokens + |
1650 | | - num_gen_tokens, :], next_n) |
| 1658 | + # CuTE DSL top-k allocates O(num_gen_tokens * kv_len) global |
| 1659 | + # memory. Beyond 256 tokens the extra memory becomes significant, |
| 1660 | + # so we cap it at 256 for now and fall back to the CUDA C++ |
| 1661 | + # indexer_topk_decode. This limit can be removed if GPU memory |
| 1662 | + # is not a bottleneck. |
| 1663 | + if self.use_cute_dsl_topk and num_gen_tokens <= 256: |
| 1664 | + torch.ops.trtllm.cute_dsl_indexer_topk_decode( |
| 1665 | + logits_decode, gen_kv_lens_cuda, |
| 1666 | + topk_indices_buffer[num_ctx_tokens:num_ctx_tokens + |
| 1667 | + num_gen_tokens, :], self.index_topk, |
| 1668 | + next_n) |
| 1669 | + else: |
| 1670 | + torch.ops.trtllm.indexer_topk_decode( |
| 1671 | + logits_decode, gen_kv_lens_cuda, |
| 1672 | + topk_indices_buffer[num_ctx_tokens:num_ctx_tokens + |
| 1673 | + num_gen_tokens, :], next_n, |
| 1674 | + self.index_topk) |
1651 | 1675 | else: |
1652 | 1676 | # padded |
1653 | 1677 | positions = torch.arange( |
|
0 commit comments