
Commit 7a5647c

committed
ModelOpt INT4 AWQ support
Signed-off-by: Frida Hou <[email protected]>

Delete tensorrt_llm/_torch/auto_deploy/models/patches/mxfp4.py
Signed-off-by: Frida Hou <[email protected]>

Delete tensorrt_llm/_torch/auto_deploy/config/default.bak.yaml
Signed-off-by: Frida Hou <[email protected]>

Delete tensorrt_llm/_torch/auto_deploy/custom_ops/int4.py
Signed-off-by: Frida Hou <[email protected]>

update torch_fake_quant_int4_linear to use standard interface
Signed-off-by: Frida Hou <[email protected]>

minor
Signed-off-by: Frida Hou <[email protected]>
1 parent 857c0b4 commit 7a5647c

File tree

7 files changed, +724 -0 lines changed


tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ transforms:
   # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
   optimize_rope:
     stage: pattern_matcher
+  quantize_int4_from_graph:
+    stage: pattern_matcher
   quantize_fp8_linear_from_config:
     stage: pattern_matcher
   quantize_nvfp4_linear_from_config:

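As a quick sanity check of the new entry, the config can be loaded directly; a minimal sketch, assuming PyYAML is installed, the script runs from the repo root, and transforms: is a top-level key as the hunk header suggests:

import yaml

# Confirm the new transform is registered in default.yaml (path as shown above)
with open("tensorrt_llm/_torch/auto_deploy/config/default.yaml") as f:
    cfg = yaml.safe_load(f)

assert "quantize_int4_from_graph" in cfg["transforms"]
assert cfg["transforms"]["quantize_int4_from_graph"]["stage"] == "pattern_matcher"
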
Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, eager_attention_forward

def _int4_unpack_out_packed_uint8_to_int8(packed: torch.Tensor) -> torch.Tensor:
    """
    Unpack weights packed along OUT dimension.
    Input: packed shape (out//2, in), dtype=uint8
    Output: unpacked int8 shape (out, in), values in [-8, 7]
    Mapping follows quantize: byte = (hi<<4) | lo with offsets in [0..15] where real = v - 8.
    """
    assert packed.dtype == torch.uint8, "Expected packed INT4 weights in uint8."
    hi = (packed >> 4).to(torch.int16) - 8  # [-8..7]
    lo = (packed & 0x0F).to(torch.int16) - 8
    # Interleave rows: even rows from hi, odd rows from lo
    out_packed, in_features = packed.shape
    out_full = out_packed * 2
    out = torch.empty((out_full, in_features), dtype=torch.int16, device=packed.device)
    out[0::2, :] = hi
    out[1::2, :] = lo
    return out.to(torch.int8)

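For reference, a round-trip sketch of the packing convention this helper assumes (even output rows in the high nibble, odd rows in the low nibble, values offset by 8); the tensors here are made up for illustration and are not part of the committed file:

# Pack a small (4, 8) int4 weight the way the unpacker expects, then unpack it back.
w_int4 = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
hi = (w_int4[0::2] + 8).to(torch.uint8)   # even output rows -> high nibble
lo = (w_int4[1::2] + 8).to(torch.uint8)   # odd output rows  -> low nibble
packed = (hi << 4) | lo                   # shape (2, 8), dtype uint8
assert torch.equal(_int4_unpack_out_packed_uint8_to_int8(packed), w_int4)
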
def _expand_scales_to_columns(
    weight_scale: torch.Tensor, block_size: int, in_features: int
) -> torch.Tensor:
    """
    Expand per-block scales to per-column scales.
    weight_scale: (out_features, in_features // block_size)
    returns: (out_features, in_features)
    """
    assert weight_scale.dim() == 2, "weight_scale should be (out, in//block_size)"
    assert in_features % block_size == 0, "in_features must be divisible by block_size"
    return weight_scale.repeat_interleave(block_size, dim=1)

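A tiny shape check of the expansion above, with made-up sizes (illustrative only):

scales = torch.rand(16, 4)                    # (out_features, in_features // block_size)
cols = _expand_scales_to_columns(scales, block_size=32, in_features=128)
assert cols.shape == (16, 128)                # each block scale repeated 32 times along dim=1
assert torch.equal(cols[:, :32], scales[:, :1].expand(16, 32))
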
def _int4_awq_linear_fallback_bak(
    x: torch.Tensor,
    packed_weight: torch.Tensor,  # (out//2, in), uint8
    weight_scale: torch.Tensor,  # (out, in//bs), float
    bias: Optional[torch.Tensor],  # (out,)
    pre_quant_scale: Optional[torch.Tensor],  # (in,) or None
    block_size: int = 128,
) -> torch.Tensor:
    """
    Pure PyTorch fallback for INT4-AWQ fake-quant linear:
        y = (x * pre_quant_scale) @ dequant(W).T + bias
    where dequant(W) applies per-(out, input_block) scales.
    """
    x_dtype = x.dtype
    if pre_quant_scale is not None:
        x = x * pre_quant_scale.to(x_dtype)

    # Unpack packed rows (out//2, in) -> (out, in) int8 in [-8..7]
    W_i8 = _int4_unpack_out_packed_uint8_to_int8(packed_weight)  # (out, in)
    out_features, in_features = W_i8.shape
    # Expand scales to per-column
    scales_cols = _expand_scales_to_columns(weight_scale, block_size, in_features)  # (out, in)
    W = (W_i8.to(torch.float32) / scales_cols.to(torch.float32)).to(x_dtype)  # dequantized
    # Linear: x @ W^T + b
    y = F.linear(x, W, bias)
    return y

def _int4_awq_linear_fallback(
    x: torch.Tensor,
    packed_weight: torch.Tensor,  # (out//2, in), uint8
    weight_scale: torch.Tensor,  # (out, in//bs), float
    bias: Optional[torch.Tensor],  # (out,)
    pre_quant_scale: Optional[torch.Tensor],  # (in,) or None
    block_size: int = 128,
) -> torch.Tensor:
    x_dtype = x.dtype
    out_features = packed_weight.shape[0] * 2
    in_features = packed_weight.shape[1]

    scale_quant_maxbound = 2 ** (4 - 1) - 1
    first_half = (packed_weight >> 4).to(torch.long) - (scale_quant_maxbound + 1)
    second_half = (packed_weight & 0x0F).to(torch.long) - (scale_quant_maxbound + 1)

    # de-quantize tensor
    first_half = first_half.view(-1, block_size // 2) / weight_scale.view(-1, 1)
    second_half = second_half.view(-1, block_size // 2) / weight_scale.view(-1, 1)

    # merge the interleaving elements
    first_half = first_half.flatten().unsqueeze(-1).transpose(0, 1)
    second_half = second_half.flatten().unsqueeze(-1).transpose(0, 1)

    W = (
        torch.stack([first_half, second_half], dim=-1)
        .view(-1)[: (out_features * in_features)]
        .reshape(out_features, in_features)
        .to(x_dtype)
    )

    # return the *projected* activations
    return F.linear(x, W, bias)

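As a smoke test of the fallback's expected layouts, here is a sketch with made-up shapes (out=8, in=256, block_size=128, so two blocks per row) and random data; it only checks that shapes flow through:

out_features, in_features, block_size = 8, 256, 128
x = torch.randn(2, in_features, dtype=torch.bfloat16)
packed_w = torch.randint(0, 256, (out_features // 2, in_features), dtype=torch.uint8)
w_scale = torch.rand(out_features, in_features // block_size) + 0.5   # keep scales away from zero
bias = torch.zeros(out_features, dtype=torch.bfloat16)
y = _int4_awq_linear_fallback(x, packed_w, w_scale, bias, pre_quant_scale=None, block_size=block_size)
assert y.shape == (2, out_features)
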
class Int4LinearAWQ(nn.Module):
    """
    Linear layer that consumes AWQ INT4 checkpoint tensors.

    Buffers/params created with exact names so load_state_dict can map:
      - weight: uint8 (out//2, in)          <-- packed int4 (two rows per byte)
      - weight_scale: float (out, in//bs)   <-- per-block scale
      - pre_quant_scale (optional): (in,)   <-- per-input scale
      - bias: (out,) if present
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        block_size: int = 128,
        has_pqs: bool = False,
        pqs_dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.block_size = int(block_size)

        # Buffers get overwritten by load_state_dict:
        # use correctly-sized placeholders to avoid shape mismatch.
        packed_shape = (self.out_features // 2, self.in_features)
        scale_shape = (self.out_features, self.in_features // self.block_size)

        self.register_buffer(
            "weight", torch.empty(packed_shape, dtype=torch.uint8, device="cuda"), persistent=True
        )
        self.register_buffer(
            "weight_scale",
            torch.empty(scale_shape, dtype=torch.float32, device="cuda"),
            persistent=True,
        )
        if has_pqs:
            # allocate with CORRECT shape to satisfy the checkpoint
            self.register_buffer(
                "pre_quant_scale",
                torch.empty(self.in_features, dtype=pqs_dtype, device="cuda"),
                persistent=True,
            )
        else:
            # truly optional: zero-length placeholder (no ckpt entry expected)
            self.register_buffer(
                "pre_quant_scale", torch.empty(0, dtype=pqs_dtype, device="cuda"), persistent=True
            )

        if bias:
            self.bias = nn.Parameter(
                torch.zeros(self.out_features, dtype=torch.bfloat16, device="cuda")
            )
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # dequant in PyTorch
        pqs = self.pre_quant_scale if self.pre_quant_scale.numel() != 0 else None
        return _int4_awq_linear_fallback(
            x, self.weight, self.weight_scale, self.bias, pqs, self.block_size
        )

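A minimal sketch of how the exact buffer names line up with a ModelOpt-style state dict; the tensor values below are random placeholders, and a CUDA device is required because the module allocates its buffers on "cuda":

lin = Int4LinearAWQ(in_features=256, out_features=8, bias=False, block_size=128, has_pqs=True)
state = {
    "weight": torch.randint(0, 256, (4, 256), dtype=torch.uint8, device="cuda"),  # packed (out//2, in)
    "weight_scale": torch.rand(8, 2, device="cuda") + 0.5,                         # (out, in//block_size)
    "pre_quant_scale": torch.rand(256, dtype=torch.bfloat16, device="cuda"),       # (in,)
}
lin.load_state_dict(state)
y = lin(torch.randn(2, 256, dtype=torch.bfloat16, device="cuda"))
assert y.shape == (2, 8)
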
class Qwen2MLP_INT4(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]

        # gate/up: no bias in original MLP
        self.gate_proj = Int4LinearAWQ(
            self.hidden_size, self.intermediate_size, bias=False, block_size=128
        )
        self.up_proj = Int4LinearAWQ(
            self.hidden_size, self.intermediate_size, bias=False, block_size=128
        )
        # down_proj has a pre_quant_scale in your checkpoint
        self.down_proj = Int4LinearAWQ(
            self.intermediate_size, self.hidden_size, bias=False, block_size=128, has_pqs=True
        )

    def forward(self, x):
        # (x * up) ⊙ act(x * gate) -> down
        up = self.up_proj(x)
        gate = self.gate_proj(x)
        y = self.act_fn(gate) * up
        return self.down_proj(y)

class Qwen2Attention_INT4(nn.Module):
    """Patched attention using INT4 AWQ linear ops; preserves original shapes/logic."""

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.sliding_window = (
            config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
        )

        # q/k/v with bias in your checkpoint
        self.q_proj = Int4LinearAWQ(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=True,
            block_size=128,
        )
        self.k_proj = Int4LinearAWQ(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            block_size=128,
        )
        self.v_proj = Int4LinearAWQ(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=True,
            block_size=128,
        )
        # o_proj without bias; has pre_quant_scale in your checkpoint
        self.o_proj = Int4LinearAWQ(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            block_size=128,
            has_pqs=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()
        hidden_shape = (bsz, q_len, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

def add_int4_awq_patch(model: nn.Module, block_size: int = 128) -> nn.Module:
    """
    Replace Qwen2Attention / Qwen2MLP with INT4-AWQ variants that expose the exact
    (weight / weight_scale / pre_quant_scale / bias) tensor names expected by ModelOpt INT4 checkpoint.

    Args:
        model: A transformers Qwen2 model (Qwen2Model or Qwen2ForCausalLM, etc.)
        block_size: INT4 AWQ group size (128)
    """
    # Find the "layers" stack (typical HF layout: model.model.layers).
    # Adjust if your model wraps differently.
    layers = None
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "layers"):
        layers = model.layers
    else:
        raise AttributeError(
            "Cannot locate 'layers' in the provided model. Expected model.model.layers or model.layers"
        )

    # Patch each transformer block's attention and MLP
    for idx, layer in enumerate(layers):
        # ATTENTION
        if hasattr(layer, "self_attn"):
            cfg = model.config
            # Re-create with the same layer index to preserve rope/sliding-window selection logic
            attn_int4 = Qwen2Attention_INT4(cfg, layer_idx=idx)
            # carry over dropout/training flags if needed (no state to copy here)
            layer.self_attn = attn_int4

            # Update block size if user passes a different one
            layer.self_attn.q_proj.block_size = block_size
            layer.self_attn.k_proj.block_size = block_size
            layer.self_attn.v_proj.block_size = block_size
            layer.self_attn.o_proj.block_size = block_size

        # MLP
        if hasattr(layer, "mlp"):
            mlp_int4 = Qwen2MLP_INT4(model.config)
            layer.mlp = mlp_int4
            layer.mlp.gate_proj.block_size = block_size
            layer.mlp.up_proj.block_size = block_size
            layer.mlp.down_proj.block_size = block_size

    return model

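Finally, one possible way to use the patch end to end on a Hugging Face Qwen2 model; the model name, checkpoint path, and safetensors layout below are illustrative assumptions, not something this commit prescribes:

import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

# Build the FP model, swap in the INT4-AWQ modules, then load the ModelOpt export by tensor name.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct", torch_dtype=torch.bfloat16)
model = add_int4_awq_patch(model, block_size=128)

state_dict = load_file("qwen2_int4_awq/model.safetensors", device="cuda")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
model.to("cuda").eval()
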
