Skip to content

Commit ae84aad

Browse files
authored
[https://nvbugs/5983390][perf] Reduce host overhead in DSA MLA attention… (#12631)
Signed-off-by: Jin Li <[email protected]>
1 parent b496348 commit ae84aad

9 files changed

Lines changed: 95 additions & 128 deletions

File tree

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -70,8 +70,8 @@ void initBindings(nb::module_& m)
7070
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
7171
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
7272
nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt,
73-
nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation",
74-
nb::call_guard<nb::gil_scoped_release>());
73+
nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
74+
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());
7575

7676
m.def(
7777
"get_helix_workspace_size_per_rank",

cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp

Lines changed: 49 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN
2828
namespace torch_ext
2929
{
3030

31-
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
32-
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
31+
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache,
32+
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens)
3333
{
34-
// Validate all tensors are CUDA tensors
35-
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
34+
// k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
35+
// k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
36+
// slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
37+
// k_cache: [num_blocks, block_size, 1, per_token_size] uint8
38+
39+
TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
3640
&& slot_mapping_scale.is_cuda(),
3741
"All tensors must be CUDA tensors");
3842

3943
// Validate tensor dimensions
40-
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
41-
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
42-
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
43-
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");
44-
45-
// Enforce k_cache is 4D tensor
46-
TORCH_CHECK(k_cache.dim() == 4,
47-
"k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
44+
TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]");
45+
TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]");
46+
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]");
47+
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]");
48+
TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims",
4849
static_cast<int>(k_cache.dim()));
4950

50-
// Validate tensor dtypes
51-
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
52-
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
51+
// Validate tensor dtypes — reinterpret_cast below assumes specific element sizes
52+
TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size());
53+
TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d",
54+
k_scale.element_size());
5355
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
5456
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");
5557

56-
// Validate tensor shapes are consistent
57-
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
58-
TORCH_CHECK(
59-
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
60-
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
61-
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");
62-
63-
// Validate tensors are contiguous (except k_cache which may be non-contiguous)
64-
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
65-
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
66-
// k_cache can be non-contiguous - we handle this via strides
58+
TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous");
59+
TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous");
6760
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
6861
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");
6962

70-
int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
71-
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes
72-
73-
int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
74-
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
75-
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
76-
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size
77-
78-
// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
79-
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
80-
TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
81-
TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);
82-
83-
int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
84-
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
85-
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
86-
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
87-
88-
auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());
89-
90-
tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
91-
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
92-
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
93-
cache_stride_1, cache_stride_2, cache_stride_3, stream);
63+
// FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
64+
int32_t const head_dim = static_cast<int32_t>(k_fp8.size(1));
65+
// Scale size in bytes: num_scale_elements * bytes_per_element.
66+
int32_t const scale_size = static_cast<int32_t>(k_scale.size(1)) * static_cast<int32_t>(k_scale.element_size());
67+
68+
int32_t const cache_dim_0 = static_cast<int32_t>(k_cache.size(0));
69+
int32_t const cache_dim_1 = static_cast<int32_t>(k_cache.size(1));
70+
int32_t const cache_dim_2 = static_cast<int32_t>(k_cache.size(2));
71+
int32_t const cache_dim_3 = static_cast<int32_t>(k_cache.size(3));
72+
73+
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2);
74+
TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim);
75+
TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size);
76+
77+
int64_t const cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
78+
int64_t const cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
79+
int64_t const cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
80+
int64_t const cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
81+
82+
auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device());
83+
84+
// Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
85+
// For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
86+
tk::invokeIndexerKCacheScatter(reinterpret_cast<uint8_t const*>(k_fp8.data_ptr()),
87+
reinterpret_cast<uint8_t const*>(k_scale.data_ptr()), k_cache.data_ptr<uint8_t>(),
88+
slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(), static_cast<int32_t>(num_tokens),
89+
head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
90+
cache_stride_2, cache_stride_3, stream);
9491
}
9592

9693
} // namespace torch_ext
@@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END
10097
TORCH_LIBRARY_FRAGMENT(trtllm, m)
10198
{
10299
m.def(
103-
"indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
104-
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
100+
"indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, "
101+
"Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()");
105102
}
106103

107104
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
630630
std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
631631
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
632632
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
633-
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits)
633+
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
634+
int64_t num_contexts, int64_t num_ctx_tokens)
634635
{
635636
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
636637
// Use these tensors to infer if the attention is using KV cache
@@ -833,20 +834,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
833834
}
834835
bool const is_gen_only = attn_input_type == AttentionInputType::GenerationOnly;
835836

836-
int32_t num_contexts = 0;
837-
// count context requests
838-
for (int32_t idx = 0; idx < num_seqs; idx++)
839-
{
840-
if (request_types[idx] != RequestType::kCONTEXT)
841-
{
842-
break;
843-
}
844-
++num_contexts;
845-
}
846-
int32_t const num_generations = num_seqs - num_contexts;
837+
int32_t const num_generations = num_seqs - static_cast<int32_t>(num_contexts);
847838
int32_t const num_tokens = qkv_or_q.size(0);
848-
int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
849-
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
839+
int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - static_cast<int32_t>(num_ctx_tokens);
850840
auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int32_t>();
851841
auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int32_t>();
852842

cpp/tensorrt_llm/thop/attentionOp.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
7878
std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
7979
std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
8080
std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata = std::nullopt,
81-
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt);
81+
std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt, int64_t num_contexts = 0,
82+
int64_t num_ctx_tokens = 0);
8283

8384
struct KvCachePoolPointers
8485
{

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 8 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1458,34 +1458,18 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
14581458
if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None:
14591459
return
14601460

1461-
# [num_blocks, block_size, 1, per_token_size ]
14621461
k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
14631462
self.layer_idx)
14641463

14651464
num_tokens = k_fp8.shape[0]
1466-
head_dim = k_fp8.shape[1]
1467-
scale_size = k_scale.shape[1] * 4 # Convert to bytes (float32 = 4 bytes)
1468-
1469-
# Convert to bytes: flatten first, then view as uint8, then reshape
1470-
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(
1471-
num_tokens, head_dim)
1472-
1473-
# k_scale: for single-element tensors, contiguous() may be no-op
1474-
# Fix stride(-1) for byte-level view
1475-
k_scale_flat = k_scale.view(-1)
1476-
if k_scale_flat.stride(-1) != 1:
1477-
k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
1478-
size=(k_scale_flat.numel(), ),
1479-
stride=(1, ))
1480-
k_scale_bytes = k_scale_flat.view(torch.uint8).view(
1481-
num_tokens, scale_size)
1482-
1483-
# Use CUDA kernel to scatter FP8 and scale bytes into cache
1484-
flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
1485-
flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
1486-
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
1487-
k_cache, flat_indices_fp8,
1488-
flat_indices_scale)
1465+
1466+
# The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw
1467+
# bytes internally and only reads the first num_tokens entries from
1468+
# the slot mapping buffers, avoiding Python-side view/slice overhead.
1469+
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache,
1470+
metadata.slot_mapping_fp8,
1471+
metadata.slot_mapping_scale,
1472+
num_tokens)
14891473

14901474
def sparse_attn_indexer(
14911475
self,

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -406,6 +406,8 @@ def run(
406406
mla_bmm1_scale: Optional[torch.Tensor] = None,
407407
mla_bmm2_scale: Optional[torch.Tensor] = None,
408408
quant_q_buffer: Optional[torch.Tensor] = None,
409+
num_contexts: int = 0,
410+
num_ctx_tokens: int = 0,
409411
):
410412
"""
411413
Run the attention operation.
@@ -652,6 +654,8 @@ def run(
652654
quant_q_buffer,
653655
self.quant_config,
654656
self.kv_cache_manager,
657+
num_contexts,
658+
num_ctx_tokens,
655659
global_layer_idx=self.global_layer_idx,
656660
)
657661
else:
@@ -736,6 +740,8 @@ def run(
736740
quant_q_buffer,
737741
self.flash_mla_tile_scheduler_metadata,
738742
self.flash_mla_num_splits,
743+
num_contexts=num_contexts,
744+
num_ctx_tokens=num_ctx_tokens,
739745
)
740746

741747
if self.print_skip_softmax_stat:
@@ -2087,7 +2093,9 @@ def forward(
20872093
fmha_scheduler_counter=fmha_scheduler_counter,
20882094
mla_bmm1_scale=mla_bmm1_scale,
20892095
mla_bmm2_scale=mla_bmm2_scale,
2090-
quant_q_buffer=quant_q_buffer)
2096+
quant_q_buffer=quant_q_buffer,
2097+
num_contexts=metadata.num_contexts,
2098+
num_ctx_tokens=metadata.num_ctx_tokens)
20912099

20922100
if output_sf is None:
20932101
return output

tensorrt_llm/_torch/attention_backend/trtllm_gen.py

Lines changed: 5 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -1437,23 +1437,6 @@ def run_mla_generation(self, params: EnqueueGenerationParams) -> None:
14371437
params.context_buf.copy_(mla_out.reshape_as(params.context_buf))
14381438

14391439

1440-
def _parse_request_types(host_request_types: torch.Tensor) -> Tuple[int, int]:
1441-
"""
1442-
Parse request types to count context and generation requests.
1443-
1444-
Args:
1445-
host_request_types: Request types tensor (0=context, 1=generation).
1446-
num_seqs: Total number of sequences.
1447-
1448-
Returns:
1449-
Tuple of (num_contexts, num_generations).
1450-
"""
1451-
1452-
num_generations = host_request_types.sum().item()
1453-
num_contexts = host_request_types.size(0) - num_generations
1454-
return num_contexts, num_generations
1455-
1456-
14571440
def is_supported(
14581441
q: torch.Tensor,
14591442
num_heads: int,
@@ -1636,6 +1619,8 @@ def trtllm_gen_attention(
16361619
quant_q_buffer: Optional[torch.Tensor],
16371620
quant_config: Optional[QuantConfig],
16381621
kv_cache_manager: Optional[KVCacheManager],
1622+
num_contexts: int,
1623+
num_ctx_tokens: int,
16391624
global_layer_idx: Optional[int] = None,
16401625
) -> None:
16411626
"""
@@ -1766,20 +1751,10 @@ def trtllm_gen_attention(
17661751
if attention_input_type is not None:
17671752
attn_input_type = AttentionInputType(attention_input_type)
17681753

1769-
num_contexts, num_generations = _parse_request_types(host_request_types)
1770-
17711754
is_gen_only = attn_input_type == AttentionInputType.generation_only
1772-
is_ctx_only = attn_input_type == AttentionInputType.context_only
1773-
1774-
if is_gen_only:
1775-
num_ctx_tokens = 0
1776-
num_gen_tokens = num_tokens
1777-
elif is_ctx_only:
1778-
num_ctx_tokens = num_tokens
1779-
num_gen_tokens = 0
1780-
else:
1781-
num_ctx_tokens = int(host_context_lengths[:num_contexts].sum()) if num_contexts > 0 else 0
1782-
num_gen_tokens = num_tokens - num_ctx_tokens
1755+
1756+
num_generations = host_request_types.size(0) - num_contexts
1757+
num_gen_tokens = num_tokens - num_ctx_tokens
17831758

17841759
# Prepare Workspace
17851760
# Use upper-bound token counts for workspace sizing to avoid repeated

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,9 @@ def __init__(self):
8282
# keeping a separate copy here since we sometimes have to overwrite the original values
8383
self.host_past_kv_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned
8484
self.host_context_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned
85+
# Batch counts for thop.attention (updated every forward in plan_host)
86+
self.num_contexts: int = 0
87+
self.num_ctx_tokens: int = 0
8588
# Persistent block_offsets buffer for CUDA graph compatibility.
8689
# Pre-allocated to max size so the tensor address is stable across replays.
8790
self.block_offsets: Optional[torch.Tensor] = None
@@ -171,6 +174,10 @@ def plan_host(
171174
"""
172175
num_seq = num_prefill + num_decode
173176

177+
# Batch counts for thop.attention
178+
self.num_contexts = num_prefill
179+
self.num_ctx_tokens = int(seq_len_host[:num_prefill].sum()) if num_prefill > 0 else 0
180+
174181
# host_request_types: 0 = prefill (context), 1 = decode (generation)
175182
self.host_request_types[:num_prefill].fill_(0)
176183
self.host_request_types[num_prefill:num_seq].fill_(1)
@@ -500,6 +507,10 @@ def trtllm_mha_with_cache(
500507
None, # mla_bmm1_scale
501508
None, # mla_bmm2_scale
502509
None, # quant_q_buffer
510+
None, # flash_mla_tile_scheduler_metadata
511+
None, # flash_mla_num_splits
512+
num_contexts=_GlobalTrtllmPlanner.num_contexts,
513+
num_ctx_tokens=_GlobalTrtllmPlanner.num_ctx_tokens,
503514
)
504515

505516
if out is not None:

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -716,7 +716,7 @@ def test_indexer_k_cache_scatter_custom_op():
716716
dtype=torch.bfloat16)
717717
k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original)
718718

719-
# Prepare byte-level data
719+
# Prepare byte-level data for the Python reference path
720720
scale_size = k_scale.shape[1] * 4
721721
k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim)
722722
k_scale_flat = k_scale.view(-1)
@@ -755,9 +755,10 @@ def test_indexer_k_cache_scatter_custom_op():
755755

756756
# ========== Path 1: CUDA Kernel ==========
757757
print("\n=== Path 1: CUDA Kernel ===")
758-
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
759-
k_cache_cuda, flat_indices_fp8,
760-
flat_indices_scale)
758+
torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache_cuda,
759+
metadata.slot_mapping_fp8,
760+
metadata.slot_mapping_scale,
761+
num_tokens)
761762
torch.cuda.synchronize()
762763
print("✓ CUDA kernel completed")
763764

0 commit comments

Comments
 (0)