Commit e57d83c

[TRTLLM-8768][chore] Fuse QK down_proj with indexer K + weight_proj for FP4 ckpt (#8771)
1 parent fdd9e4f commit e57d83c

4 files changed: +175 -105 lines

tensorrt_llm/_torch/attention_backend/sparse/dsa.py
Lines changed: 54 additions & 32 deletions

@@ -629,6 +629,7 @@ def __init__(self,
         self.scale_fmt = "ue8m0"
         self.aux_stream = aux_stream
         self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
+        self.weight_scale_factor = self.softmax_scale * self.n_heads**-0.5

     @staticmethod
     def prepare_one_prefill_chunk(
@@ -1105,65 +1106,86 @@ def sparse_attn_indexer(

         return topk_indices_buffer

+    def weight_scale(self, hidden_states: torch.Tensor,
+                     indexer_weights: Optional[torch.Tensor],
+                     q_scale: torch.Tensor) -> torch.Tensor:
+        weights = indexer_weights if indexer_weights is not None else self.weights_proj(
+            hidden_states)
+        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
+        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
+        weights = weights.squeeze(-1)
+        return weights
+
     @torch.inference_mode()
     def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor,
                 metadata: DSAtrtllmAttentionMetadata,
-                position_ids: torch.Tensor):
+                position_ids: torch.Tensor, indexer_k: Optional[torch.Tensor],
+                indexer_weights: Optional[torch.Tensor]):
         quant_block_size = metadata.kv_cache_manager.quant_block_size
         assert quant_block_size == 128, "Only support quant_block_size = 128 for now"

+        if indexer_k is not None:
+            q, k = maybe_execute_in_parallel(
+                lambda: self.wq_b(
+                    qr),  # TODO: fuse wq_b and move this outside of the indexer
+                lambda: self.k_norm(indexer_k),
+                self.ln_events[0],
+                self.ln_events[1],
+                self.aux_stream,
+            )
+        else:
+            q, k = maybe_execute_in_parallel(
+                lambda: self.wq_b(qr),
+                lambda: self.k_norm(self.wk(hidden_states)),
+                self.ln_events[0],
+                self.ln_events[1],
+                self.aux_stream,
+            )
+
+        # q/k rope + possible fast_hadamard_transform
+        q = q.view(-1, self.n_heads, self.head_dim)
+
         q, k = maybe_execute_in_parallel(
-            lambda: self.wq_b(qr),
-            lambda: self.wk(hidden_states),
+            lambda: torch.split(
+                q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
+            lambda: torch.split(
+                k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1),
             self.ln_events[0],
             self.ln_events[1],
             self.aux_stream,
         )
-        q = q.view(-1, self.n_heads, self.head_dim)
-        q_pe, q_nope = torch.split(
-            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
-        k = self.k_norm(k)
-        k_pe, k_nope = torch.split(
-            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

-        # k_pe needs unsqueeze to match n_heads
+        q_pe, q_nope = q
+        k_pe, k_nope = k
         q_pe, k_pe = self.rotary_emb(position_ids, [q_pe, k_pe.unsqueeze(1)])
-        q = torch.cat([q_pe, q_nope], dim=-1)
-        # Remove head dimension (size 1) for MQA k
-        k = torch.cat([k_pe[:, 0, :], k_nope], dim=-1)

-        q, k = maybe_execute_in_parallel(
-            lambda: rotate_activation(q),
-            lambda: rotate_activation(k),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
-        )
-        # we only quant q here since k quant is fused with cache insertion
-        q = q.view(-1, self.head_dim)
+        k_pe = k_pe[:, 0, :]
+
+        def _prep_q_or_k(qk_pe, qk_nope):
+            q_or_k = torch.cat([qk_pe, qk_nope], dim=-1)
+            q_or_k = rotate_activation(q_or_k)
+            q_or_k = q_or_k.view(-1, self.head_dim)
+            q_or_k = fp8_utils.fp8_quantize_1x128_sf_transpose(
+                q_or_k, use_ue8m0=self.scale_fmt == "ue8m0")
+            return q_or_k

         q, k = maybe_execute_in_parallel(
-            lambda: fp8_utils.fp8_quantize_1x128_sf_transpose(
-                q, use_ue8m0=self.scale_fmt == "ue8m0"),
-            lambda: fp8_utils.fp8_quantize_1x128_sf_transpose(
-                k, use_ue8m0=self.scale_fmt == "ue8m0"),
+            lambda: _prep_q_or_k(q_pe, q_nope),
+            lambda: _prep_q_or_k(k_pe, k_nope),
             self.ln_events[0],
             self.ln_events[1],
             self.aux_stream,
         )
+
         q_fp8, q_scale = q
         k_fp8, k_scale = k
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)

-        weights = self.weights_proj(hidden_states)
-        weights = weights.unsqueeze(
-            -1) * q_scale * self.softmax_scale * self.n_heads**-0.5
-        weights = weights.squeeze(-1)
-
+        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
-                                        k_scale, weights.to(torch.float32))
+                                        k_scale, weights)


 class DSATrtllmAttention(TrtllmAttention):
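
Note on this file: the per-call scaling softmax_scale * n_heads**-0.5 is folded into a precomputed weight_scale_factor, and the explicit weights.to(torch.float32) is dropped because multiplying by the float32 q_scale already promotes the product. A minimal standalone sketch that checks both claims; shapes, the scale value, and tensor contents are made up for illustration, and this is not the TensorRT-LLM module:

# Sketch only: verify that the folded factor matches the old inline scaling and
# that a bfloat16 weights_proj output times a float32 q_scale promotes to float32.
import torch

num_tokens, n_heads = 4, 8                              # hypothetical sizes
softmax_scale = 0.135                                   # hypothetical value
weight_scale_factor = softmax_scale * n_heads**-0.5     # folded once in __init__

weights = torch.randn(num_tokens, n_heads, dtype=torch.bfloat16)   # stand-in for weights_proj output
q_scale = torch.rand(num_tokens, n_heads, 1, dtype=torch.float32)  # per-head FP8 scale factors

old = (weights.unsqueeze(-1) * q_scale * softmax_scale * n_heads**-0.5).squeeze(-1)
new = (weights.unsqueeze(-1) * q_scale * weight_scale_factor).squeeze(-1)

assert new.dtype == torch.float32   # promoted by the float32 q_scale, no .to() needed
assert torch.allclose(old, new)     # same numerics, one fewer multiply per forward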

tensorrt_llm/_torch/models/modeling_deepseekv3.py
Lines changed: 86 additions & 6 deletions

@@ -363,8 +363,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                     [q_a_proj_scale, fused_a_scale], dim=0)

                 module.weight_scale.data.copy_(fused_a_scale)
-
-            module.weight.data.copy_(fused_a)
+            # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized
+            # to include indexer weights, which is filled in post_load_weights.
+            module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
         elif names[-1] in params_map:
             module_weights = []
             for new_name in params_map[names[-1]]:
@@ -544,6 +545,79 @@ def __init__(
             use_custom_cublas_mm=True)


+class DeepseekV32Attention(MLA):
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+        aux_stream: Optional[torch.cuda.Stream] = None,
+    ):
+        config = model_config.pretrained_config
+        predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
+
+        # DSV3.2 nvfp4 ckpt has kv_a_proj_with_mqa module in bfloat16
+        # TODO: check it more directly/robustly, e.g., indexer_weight_quant == fuseA_quant == indexer_quant
+        if model_config.get_quant_config().quant_algo == QuantAlgo.NVFP4:
+            self.fuse_a_indexer_k_weight = True
+        else:
+            self.fuse_a_indexer_k_weight = False
+
+        super().__init__(hidden_size=config.hidden_size,
+                         num_attention_heads=config.num_attention_heads,
+                         num_key_value_heads=config.num_key_value_heads,
+                         qk_rope_head_dim=config.qk_rope_head_dim,
+                         qk_nope_head_dim=config.qk_nope_head_dim,
+                         q_lora_rank=config.q_lora_rank,
+                         kv_lora_rank=config.kv_lora_rank,
+                         v_head_dim=config.v_head_dim,
+                         predicted_tokens_per_seq=predicted_tokens_per_seq,
+                         max_position_embeddings=config.max_position_embeddings,
+                         bias=False,
+                         pos_embd_params=PositionalEmbeddingParams(
+                             type=PositionEmbeddingType.yarn,
+                             rope=RopeParams.from_config(config),
+                             is_neox=False,
+                         ),
+                         layer_idx=layer_idx,
+                         dtype=config.torch_dtype,
+                         config=model_config,
+                         aux_stream=aux_stream)
+
+        self.indexer = self.mqa.indexer
+
+        if self.fuse_a_indexer_k_weight:
+            # For DeepseekV32, the kv_a_proj_with_mqa includes:
+            # q_a_proj + kv_a_proj_with_mqa + indexer.wk + indexer.weights_proj
+            self.kv_a_proj_with_mqa = DeepseekV3Linear(
+                config.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank +
+                self.indexer.head_dim + self.indexer.n_heads,
+                bias=False,
+                dtype=config.torch_dtype,
+                quant_config=model_config.get_quant_config(),
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+                use_custom_cublas_mm=True)
+
+    def post_load_weights(self):
+        if self.fuse_a_indexer_k_weight:
+            assert self.kv_a_proj_with_mqa.weight.data.dtype == self.indexer.wk.weight.data.dtype == self.indexer.weights_proj.weight.data.dtype, "all weights in kv_a_proj_with_mqa module must have matching dtype"
+            # Copy indexer weights into the fused kv_a_proj_with_mqa module
+            indexer_wk_weight = self.indexer.wk.weight.data
+            offset = self.kv_lora_rank + self.qk_rope_head_dim + self.q_lora_rank
+            self.kv_a_proj_with_mqa.weight.data[offset:offset +
+                                                self.indexer.head_dim].copy_(
+                                                    indexer_wk_weight)
+            offset += self.indexer.head_dim
+            indexer_weights_proj_weight = self.indexer.weights_proj.weight.data
+            self.kv_a_proj_with_mqa.weight.data[offset:offset +
+                                                self.indexer.n_heads].copy_(
+                                                    indexer_weights_proj_weight)
+            self.indexer.wk = None
+            self.indexer.weights_proj = None
+
+
 class Deepseekv3RoutingImpl():

     def __init__(
@@ -952,10 +1026,16 @@ def __init__(self,
             #KVCacheManager only support 1 layer for separate draft engine
             layer_idx_for_attention = layer_idx - model_config.pretrained_config.num_hidden_layers

-        self.self_attn = DeepseekV3Attention(
-            model_config,
-            layer_idx=layer_idx_for_attention,
-            aux_stream=aux_stream_dict[AuxStreamType.Attention])
+        if config.model_type == "deepseek_v32":
+            self.self_attn = DeepseekV32Attention(
+                model_config,
+                layer_idx=layer_idx_for_attention,
+                aux_stream=aux_stream_dict[AuxStreamType.Attention])
+        else:
+            self.self_attn = DeepseekV3Attention(
+                model_config,
+                layer_idx=layer_idx_for_attention,
+                aux_stream=aux_stream_dict[AuxStreamType.Attention])
         self.enable_attention_dp = mapping.enable_attention_dp

         self.mlp_tp_size = mapping.tp_size
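
The fusion this file introduces, an oversized kv_a_proj_with_mqa whose trailing output rows are filled from indexer.wk and indexer.weights_proj in post_load_weights, can be illustrated with a toy sketch. Sizes are hypothetical and plain nn.Linear stands in for DeepseekV3Linear; the point is only that one fused GEMM reproduces the separate projections and is recovered by a single split:

# Sketch only: fuse several projections into one Linear by stacking their weight rows,
# then split the fused output back into the original pieces.
import torch
import torch.nn as nn

hidden, q_lora, kv_lora, rope = 32, 16, 12, 8        # hypothetical base widths
idx_head_dim, idx_n_heads = 10, 4                    # hypothetical indexer widths

base = nn.Linear(hidden, q_lora + kv_lora + rope, bias=False)   # q_a + kv_a + k_pe rows
wk = nn.Linear(hidden, idx_head_dim, bias=False)                # stands in for indexer.wk
weights_proj = nn.Linear(hidden, idx_n_heads, bias=False)       # stands in for indexer.weights_proj

fused = nn.Linear(hidden,
                  q_lora + kv_lora + rope + idx_head_dim + idx_n_heads,
                  bias=False)
with torch.no_grad():
    fused.weight[:base.out_features].copy_(base.weight)     # the part load_weights fills
    off = base.out_features                                 # the part post_load_weights fills
    fused.weight[off:off + idx_head_dim].copy_(wk.weight)
    off += idx_head_dim
    fused.weight[off:off + idx_n_heads].copy_(weights_proj.weight)

x = torch.randn(3, hidden)
q, ckv, k_pe, indexer_k, indexer_w = fused(x).split(
    [q_lora, kv_lora, rope, idx_head_dim, idx_n_heads], dim=-1)
ref = torch.cat([base(x), wk(x), weights_proj(x)], dim=-1)
assert torch.allclose(torch.cat([q, ckv, k_pe, indexer_k, indexer_w], dim=-1),
                      ref, atol=1e-6)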

tensorrt_llm/_torch/modules/attention.py
Lines changed: 26 additions & 14 deletions

@@ -963,8 +963,6 @@ def yarn_get_mscale(scale=1, mscale=1):
         if not config.skip_create_weights_in_init:
             self.create_weights()

-        self.indexer = self.mqa.indexer if self.is_dsa else None
-
     def create_weights(self):
         # self.mha/mqa has no weights but has states that are related to quant_config,
         # which could be modified after __init__
@@ -1234,9 +1232,21 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
         if position_ids is not None:
             position_ids = position_ids[..., :num_tokens]

-        q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
-            [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], -1)
+        if self.fuse_a_indexer_k_weight:
+            q, compressed_kv, k_pe, indexer_k, indexer_weights = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
+                    self.indexer.head_dim, self.indexer.n_heads
+                ], -1)
+        else:
+            q, compressed_kv, k_pe = self.kv_a_proj_with_mqa(
+                hidden_states).split([
+                    self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim
+                ], -1)
+            indexer_k = None
+            indexer_weights = None

+        # TODO: possibly overlap/fuse q_a_rmsnorm + kv_a_rmsnorm + indexer.k_layernorm?
         q, compressed_kv = maybe_execute_in_parallel(
             lambda: self.q_a_layernorm(q),
             lambda: self.kv_a_layernorm(compressed_kv),
@@ -1245,23 +1255,25 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor],
             self.aux_stream,
         )
         qr = q
-        q, latent_cache = maybe_execute_in_parallel(
-            lambda: self.q_b_proj(q),
-            lambda: torch.concat([compressed_kv, k_pe], dim=-1),
-            self.ln_events[0],
-            self.ln_events[1],
-            self.aux_stream,
+        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+
+        # TODO: fuse wq_b + (indexer) wlq here
+        q = self.q_b_proj(q)
+        # Indexer
+        topk_indices = self.indexer(
+            qr,
+            hidden_states,
+            attn_metadata,
+            position_ids,
+            indexer_k=indexer_k,  # indexer K proj
+            indexer_weights=indexer_weights,  # indexer weights proj
         )

         assert q.shape[
             0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

         assert output is not None, "output must be provided"

-        # Indexer
-        topk_indices = self.indexer(qr, hidden_states, attn_metadata,
-                                    position_ids)
-
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
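
Several hunks in this commit route pairs of independent ops through maybe_execute_in_parallel (the q_a_layernorm/kv_a_layernorm overlap kept above, the indexer's q/k work in dsa.py). A generic sketch of that two-stream pattern, written from scratch as an illustration rather than copied from the TensorRT-LLM helper:

# Sketch only: overlap fn_a (current stream) with fn_b (aux stream), using two events
# to order the hand-off. Real code must also manage tensor lifetimes across streams
# (e.g. record_stream); that is omitted here.
import torch

def run_in_parallel(fn_a, fn_b, ev_start: torch.cuda.Event,
                    ev_done: torch.cuda.Event, aux_stream: torch.cuda.Stream):
    ev_start.record()                  # inputs for fn_b are ready on the current stream
    with torch.cuda.stream(aux_stream):
        ev_start.wait()                # aux stream waits for those inputs
        out_b = fn_b()
        ev_done.record()               # fn_b's outputs are ready
    out_a = fn_a()                     # enqueued on the current stream, overlapping fn_b
    ev_done.wait()                     # current stream waits before consuming out_b
    return out_a, out_b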

tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py
Lines changed: 9 additions & 53 deletions

@@ -443,6 +443,7 @@ def yarn_get_mscale(scale=1, mscale=1):
            dtype=dtype,
            config=model_config,
        ).to(device)
+        mla.indexer = mla.mqa.indexer.to(device)
        mla_layers.append(mla)

    # Use the test layer
@@ -675,60 +676,20 @@ def yarn_get_mscale(scale=1, mscale=1):
        sum(batch_query_lens[:i + 1]) for i in range(len(batch_order))
    ]
    num_ctx_tokens = sum(query_lens[i] for i in ctx_indices)
-
-    def create_causal_indices(req_indices, cache_offset_start=0):
-        """Helper to create causal attention indices with padding."""
-        indices = []
-        kv_offset = cache_offset_start
-        for req_idx in req_indices:
-            for local_pos in range(query_lens[req_idx]):
-                num_attend = min(cached_lens[req_idx] + local_pos + 1,
-                                 topk_tokens)
-                attend_indices = torch.arange(
-                    num_attend, dtype=torch.int32, device=device) + kv_offset
-                if num_attend < topk_tokens:
-                    padding = torch.full((topk_tokens - num_attend, ),
-                                         -1,
-                                         dtype=torch.int32,
-                                         device=device)
-                    attend_indices = torch.cat([attend_indices, padding])
-                indices.append(attend_indices)
-            kv_offset += seq_lens[req_idx]
-        return torch.stack(indices, dim=0)
-
-    def local_to_global_indices(local_indices,
-                                req_indices,
-                                cache_offset_start=0):
-        """
-        Transform indexer's local indices to global indices.
-        """
-        global_indices = local_indices.clone()
-        kv_offset = cache_offset_start
-        token_idx = 0
-
-        for req_idx in req_indices:
-            num_tokens = query_lens[req_idx]
-            # Add offset for this request's cache position
-            for local_pos in range(num_tokens):
-                # Only transform non-padding indices (>= 0)
-                mask = global_indices[token_idx] >= 0
-                global_indices[token_idx][mask] += kv_offset
-                token_idx += 1
-            kv_offset += seq_lens[req_idx]
-        return global_indices
-
-    topk_indices_local = mla.mqa.indexer(qr, hidden_states, attn_metadata,
-                                         position_ids)
+    topk_indices_local = mla.mqa.indexer(
+        qr,
+        hidden_states,
+        attn_metadata,
+        position_ids,
+        None,  # indexer_k
+        None,  # indexer_weights
+    )

    # Validate indexer output against expected causal indices (since seq_len < topk=2048)
    if num_contexts > 0:
        # Transform context indices from local to global
        ctx_topk_local = topk_indices_local[:num_ctx_tokens]

-        # Create expected global indices (sorted) for validation (not used but can be used for validation)
-        expected_ctx_indices = create_causal_indices(ctx_indices,
-                                                     cache_offset_start=0)
-
        mla.forward_context_dsa(
            q=q[:num_ctx_tokens],
            compressed_kv=compressed_kv[:num_ctx_tokens],
@@ -747,11 +708,6 @@ def local_to_global_indices(local_indices,
    num_gen_tokens = sum(query_lens[i] for i in gen_indices)
    gen_topk_local = topk_indices_local[num_ctx_tokens:num_ctx_tokens +
                                        num_gen_tokens]
-
-    # Create expected global indices (sorted) for validation (not used but can be used for validation)
-    expected_gen_indices = create_causal_indices(gen_indices,
-                                                 cache_offset_start=0)
-
    mla.forward_generation_dsa(
        q=q[num_ctx_tokens:],
        compressed_kv=compressed_kv[num_ctx_tokens:],
