|
| 1 | +from typing import Callable, Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +import torch.nn.functional as F |
| 6 | +from transformers.activations import ACT2FN |
| 7 | +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| 8 | +from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, eager_attention_forward |
| 9 | + |
| 10 | + |
| 11 | +def _int4_unpack_out_packed_uint8_to_int8(packed: torch.Tensor) -> torch.Tensor: |
| 12 | + """ |
| 13 | + Unpack weights packed along OUT dimension. |
| 14 | + Input: packed shape (out//2, in), dtype=uint8 |
| 15 | + Output: unpacked int8 shape (out, in), values in [-8, 7] |
| 16 | + Mapping follows quantize: byte = (hi<<4) | lo with offsets in [0..15] where real = v - 8. |
| 17 | + """ |
| 18 | + assert packed.dtype == torch.uint8, "Expected packed INT4 weights in uint8." |
| 19 | + hi = (packed >> 4).to(torch.int16) - 8 # [-8..7] |
| 20 | + lo = (packed & 0x0F).to(torch.int16) - 8 |
| 21 | + # Interleave rows: even rows from hi, odd rows from lo |
| 22 | + out_packed, in_features = packed.shape |
| 23 | + out_full = out_packed * 2 |
| 24 | + out = torch.empty((out_full, in_features), dtype=torch.int16, device=packed.device) |
| 25 | + out[0::2, :] = hi |
| 26 | + out[1::2, :] = lo |
| 27 | + return out.to(torch.int8) |
| 28 | + |
| 29 | + |
| 30 | +def _expand_scales_to_columns( |
| 31 | + weight_scale: torch.Tensor, block_size: int, in_features: int |
| 32 | +) -> torch.Tensor: |
| 33 | + """ |
| 34 | + Expand per-block scales to per-column scales. |
| 35 | + weight_scale: (out_features, in_features // block_size) |
| 36 | + returns: (out_features, in_features) |
| 37 | + """ |
| 38 | + assert weight_scale.dim() == 2, "weight_scale should be (out, in//block_size)" |
| 39 | + assert in_features % block_size == 0, "in_features must be divisible by block_size" |
| 40 | + return weight_scale.repeat_interleave(block_size, dim=1) |
| 41 | + |
| 42 | + |
| 43 | +def _int4_awq_linear_fallback_bak( |
| 44 | + x: torch.Tensor, |
| 45 | + packed_weight: torch.Tensor, # (out//2, in), uint8 |
| 46 | + weight_scale: torch.Tensor, # (out, in//bs), float |
| 47 | + bias: Optional[torch.Tensor], # (out,) |
| 48 | + pre_quant_scale: Optional[torch.Tensor], # (in,) or None |
| 49 | + block_size: int = 128, |
| 50 | +) -> torch.Tensor: |
| 51 | + """ |
| 52 | + Pure PyTorch fallback for INT4-AWQ fake-quant linear: |
| 53 | + y = (x * pre_quant_scale) @ dequant(W).T + bias |
| 54 | + where dequant(W) applies per-(out, input_block) scales. |
| 55 | + """ |
| 56 | + x_dtype = x.dtype |
| 57 | + if pre_quant_scale is not None: |
| 58 | + x = x * pre_quant_scale.to(x_dtype) |
| 59 | + |
| 60 | + # Unpack packed rows (out//2, in) -> (out, in) int8 in [-8..7] |
| 61 | + W_i8 = _int4_unpack_out_packed_uint8_to_int8(packed_weight) # (out, in) |
| 62 | + out_features, in_features = W_i8.shape |
| 63 | + # Expand scales to per-column |
| 64 | + scales_cols = _expand_scales_to_columns(weight_scale, block_size, in_features) # (out, in) |
| 65 | + W = (W_i8.to(torch.float32) / scales_cols.to(torch.float32)).to(x_dtype) # dequantized |
| 66 | + # Linear: x @ W^T + b |
| 67 | + y = F.linear(x, W, bias) |
| 68 | + return y |
| 69 | + |
| 70 | + |
| 71 | +def _int4_awq_linear_fallback( |
| 72 | + x: torch.Tensor, |
| 73 | + packed_weight: torch.Tensor, # (out//2, in), uint8 |
| 74 | + weight_scale: torch.Tensor, # (out, in//bs), float |
| 75 | + bias: Optional[torch.Tensor], # (out,) |
| 76 | + pre_quant_scale: Optional[torch.Tensor], # (in,) or None |
| 77 | + block_size: int = 128, |
| 78 | +) -> torch.Tensor: |
| 79 | + x_dtype = x.dtype |
| 80 | + out_features = packed_weight.shape[0] * 2 |
| 81 | + in_features = packed_weight.shape[1] |
| 82 | + |
| 83 | + scale_quant_maxbound = 2 ** (4 - 1) - 1 |
| 84 | + first_half = (packed_weight >> 4).to(torch.long) - (scale_quant_maxbound + 1) |
| 85 | + second_half = (packed_weight & 0x0F).to(torch.long) - (scale_quant_maxbound + 1) |
| 86 | + |
| 87 | + # de-quantize tensor |
| 88 | + first_half = first_half.view(-1, block_size // 2) / weight_scale.view(-1, 1) |
| 89 | + second_half = second_half.view(-1, block_size // 2) / weight_scale.view(-1, 1) |
| 90 | + |
| 91 | + # merge the interleaving elements |
| 92 | + first_half = first_half.flatten().unsqueeze(-1).transpose(0, 1) |
| 93 | + second_half = second_half.flatten().unsqueeze(-1).transpose(0, 1) |
| 94 | + |
| 95 | + W = ( |
| 96 | + torch.stack([first_half, second_half], dim=-1) |
| 97 | + .view(-1)[: (out_features * in_features)] |
| 98 | + .reshape(out_features, in_features) |
| 99 | + .to(x_dtype) |
| 100 | + ) |
| 101 | + |
| 102 | + # return the *projected* activations |
| 103 | + return F.linear(x, W, bias) |
| 104 | + |
| 105 | + |
| 106 | +class Int4LinearAWQ(nn.Module): |
| 107 | + """ |
| 108 | + Linear layer that consumes AWQ INT4 checkpoint tensors. |
| 109 | + |
| 110 | + Buffers/params created with exact names so load_state_dict can map: |
| 111 | + - weight: uint8 (out//2, in) <-- packed int4 (two rows per byte) |
| 112 | + - weight_scale: float (out, in//bs) <-- per-block scale |
| 113 | + - pre_quant_scale (optional): (in,) <-- per-input scale |
| 114 | + - bias: (out,) if present |
| 115 | + """ |
| 116 | + |
| 117 | + def __init__( |
| 118 | + self, |
| 119 | + in_features: int, |
| 120 | + out_features: int, |
| 121 | + bias: bool, |
| 122 | + block_size: int = 128, |
| 123 | + has_pqs: bool = False, |
| 124 | + pqs_dtype: torch.dtype = torch.bfloat16, |
| 125 | + ): |
| 126 | + super().__init__() |
| 127 | + self.in_features = int(in_features) |
| 128 | + self.out_features = int(out_features) |
| 129 | + self.block_size = int(block_size) |
| 130 | + |
| 131 | + # Buffers get overwritten by load_state_dict: |
| 132 | + # Use correctly-sized placeholders to avoid shape mismatch. |
| 133 | + packed_shape = (self.out_features // 2, self.in_features) |
| 134 | + scale_shape = (self.out_features, self.in_features // self.block_size) |
| 135 | + |
| 136 | + self.register_buffer( |
| 137 | + "weight", torch.empty(packed_shape, dtype=torch.uint8, device="cuda"), persistent=True |
| 138 | + ) |
| 139 | + self.register_buffer( |
| 140 | + "weight_scale", |
| 141 | + torch.empty(scale_shape, dtype=torch.float32, device="cuda"), |
| 142 | + persistent=True, |
| 143 | + ) |
| 144 | + if has_pqs: |
| 145 | + # allocate with CORRECT shape to satisfy the checkpoint |
| 146 | + self.register_buffer( |
| 147 | + "pre_quant_scale", |
| 148 | + torch.empty(self.in_features, dtype=pqs_dtype, device="cuda"), |
| 149 | + persistent=True, |
| 150 | + ) |
| 151 | + else: |
| 152 | + # truly optional: zero-length placeholder (no ckpt entry expected) |
| 153 | + self.register_buffer( |
| 154 | + "pre_quant_scale", torch.empty(0, dtype=pqs_dtype, device="cuda"), persistent=True |
| 155 | + ) |
| 156 | + |
| 157 | + if bias: |
| 158 | + self.bias = nn.Parameter( |
| 159 | + torch.zeros(self.out_features, dtype=torch.bfloat16, device="cuda") |
| 160 | + ) |
| 161 | + else: |
| 162 | + self.register_parameter("bias", None) |
| 163 | + |
| 164 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 165 | + # dequant in PyTorch |
| 166 | + pqs = self.pre_quant_scale if self.pre_quant_scale.numel() != 0 else None |
| 167 | + return _int4_awq_linear_fallback( |
| 168 | + x, self.weight, self.weight_scale, self.bias, pqs, self.block_size |
| 169 | + ) |
| 170 | + |
| 171 | + |
| 172 | +class Qwen2MLP_INT4(nn.Module): |
| 173 | + def __init__(self, config): |
| 174 | + super().__init__() |
| 175 | + self.config = config |
| 176 | + self.hidden_size = config.hidden_size |
| 177 | + self.intermediate_size = config.intermediate_size |
| 178 | + self.act_fn = ACT2FN[config.hidden_act] |
| 179 | + |
| 180 | + # gate/up: no bias in original MLP |
| 181 | + self.gate_proj = Int4LinearAWQ( |
| 182 | + self.hidden_size, self.intermediate_size, bias=False, block_size=128 |
| 183 | + ) |
| 184 | + self.up_proj = Int4LinearAWQ( |
| 185 | + self.hidden_size, self.intermediate_size, bias=False, block_size=128 |
| 186 | + ) |
| 187 | + # down_proj has a pre_quant_scale in your checkpoint |
| 188 | + self.down_proj = Int4LinearAWQ( |
| 189 | + self.intermediate_size, self.hidden_size, bias=False, block_size=128, has_pqs=True |
| 190 | + ) |
| 191 | + |
| 192 | + def forward(self, x): |
| 193 | + # (x * up) ⊙ act(x * gate) -> down |
| 194 | + up = self.up_proj(x) |
| 195 | + gate = self.gate_proj(x) |
| 196 | + y = self.act_fn(gate) * up |
| 197 | + return self.down_proj(y) |
| 198 | + |
| 199 | + |
| 200 | +class Qwen2Attention_INT4(nn.Module): |
| 201 | + """Patched attention using INT4 AWQ linear ops; preserves original shapes/logic.""" |
| 202 | + |
| 203 | + def __init__(self, config, layer_idx: int): |
| 204 | + super().__init__() |
| 205 | + self.config = config |
| 206 | + self.layer_idx = layer_idx |
| 207 | + self.head_dim = getattr( |
| 208 | + config, "head_dim", config.hidden_size // config.num_attention_heads |
| 209 | + ) |
| 210 | + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads |
| 211 | + self.scaling = self.head_dim**-0.5 |
| 212 | + self.attention_dropout = config.attention_dropout |
| 213 | + self.is_causal = True |
| 214 | + self.sliding_window = ( |
| 215 | + config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None |
| 216 | + ) |
| 217 | + |
| 218 | + # q/k/v with bias in your checkpoint |
| 219 | + self.q_proj = Int4LinearAWQ( |
| 220 | + config.hidden_size, |
| 221 | + config.num_attention_heads * self.head_dim, |
| 222 | + bias=True, |
| 223 | + block_size=128, |
| 224 | + ) |
| 225 | + self.k_proj = Int4LinearAWQ( |
| 226 | + config.hidden_size, |
| 227 | + config.num_key_value_heads * self.head_dim, |
| 228 | + bias=True, |
| 229 | + block_size=128, |
| 230 | + ) |
| 231 | + self.v_proj = Int4LinearAWQ( |
| 232 | + config.hidden_size, |
| 233 | + config.num_key_value_heads * self.head_dim, |
| 234 | + bias=True, |
| 235 | + block_size=128, |
| 236 | + ) |
| 237 | + # o_proj without bias; has pre_quant_scale in your checkpoint |
| 238 | + self.o_proj = Int4LinearAWQ( |
| 239 | + config.num_attention_heads * self.head_dim, |
| 240 | + config.hidden_size, |
| 241 | + bias=False, |
| 242 | + block_size=128, |
| 243 | + has_pqs=True, |
| 244 | + ) |
| 245 | + |
| 246 | + def forward( |
| 247 | + self, |
| 248 | + hidden_states: torch.Tensor, |
| 249 | + position_embeddings: tuple[torch.Tensor, torch.Tensor], |
| 250 | + attention_mask: Optional[torch.Tensor], |
| 251 | + cache_position: Optional[torch.LongTensor] = None, |
| 252 | + **kwargs, |
| 253 | + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| 254 | + bsz, q_len, _ = hidden_states.size() |
| 255 | + hidden_shape = (bsz, q_len, -1, self.head_dim) |
| 256 | + |
| 257 | + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| 258 | + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| 259 | + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) |
| 260 | + |
| 261 | + cos, sin = position_embeddings |
| 262 | + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
| 263 | + |
| 264 | + attention_interface: Callable = eager_attention_forward |
| 265 | + if self.config._attn_implementation != "eager": |
| 266 | + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
| 267 | + |
| 268 | + attn_output, attn_weights = attention_interface( |
| 269 | + self, |
| 270 | + query_states, |
| 271 | + key_states, |
| 272 | + value_states, |
| 273 | + attention_mask, |
| 274 | + dropout=0.0 if not self.training else self.attention_dropout, |
| 275 | + scaling=self.scaling, |
| 276 | + sliding_window=self.sliding_window, |
| 277 | + **kwargs, |
| 278 | + ) |
| 279 | + |
| 280 | + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
| 281 | + attn_output = self.o_proj(attn_output) |
| 282 | + return attn_output, attn_weights |
| 283 | + |
| 284 | + |
| 285 | +def add_int4_awq_patch(model: nn.Module, block_size: int = 128) -> nn.Module: |
| 286 | + """ |
| 287 | + Replace Qwen2Attention / Qwen2MLP with INT4-AWQ variants that expose the exact |
| 288 | + (weight / weight_scale / pre_quant_scale / bias) tensor names expected by ModelOpt INT4 checkpoint. |
| 289 | + |
| 290 | + Args: |
| 291 | + model: A transformers Qwen2 model (Qwen2Model or Qwen2ForCausalLM, etc.) |
| 292 | + block_size: INT4 AWQ group size (128) |
| 293 | + """ |
| 294 | + # Find the "layers" stack — typical HF layout: model.model.layers |
| 295 | + # Adjust if your model wraps differently. |
| 296 | + layers = None |
| 297 | + if hasattr(model, "model") and hasattr(model.model, "layers"): |
| 298 | + layers = model.model.layers |
| 299 | + elif hasattr(model, "layers"): |
| 300 | + layers = model.layers |
| 301 | + else: |
| 302 | + raise AttributeError( |
| 303 | + "Cannot locate 'layers' in the provided model. Expected model.model.layers or model.layers" |
| 304 | + ) |
| 305 | + |
| 306 | + # Patch each transformer block’s attention and MLP |
| 307 | + for idx, layer in enumerate(layers): |
| 308 | + # ATTENTION |
| 309 | + if hasattr(layer, "self_attn"): |
| 310 | + cfg = model.config |
| 311 | + # Re-create with the same layer index to preserve rope/sliding-window selection logic |
| 312 | + attn_int4 = Qwen2Attention_INT4(cfg, layer_idx=idx) |
| 313 | + # carry over dropout/training flags if needed (no state to copy here) |
| 314 | + layer.self_attn = attn_int4 |
| 315 | + |
| 316 | + # Update block size if user passes a different one |
| 317 | + layer.self_attn.q_proj.block_size = block_size |
| 318 | + layer.self_attn.k_proj.block_size = block_size |
| 319 | + layer.self_attn.v_proj.block_size = block_size |
| 320 | + layer.self_attn.o_proj.block_size = block_size |
| 321 | + |
| 322 | + # MLP |
| 323 | + if hasattr(layer, "mlp"): |
| 324 | + mlp_int4 = Qwen2MLP_INT4(model.config) |
| 325 | + layer.mlp = mlp_int4 |
| 326 | + layer.mlp.gate_proj.block_size = block_size |
| 327 | + layer.mlp.up_proj.block_size = block_size |
| 328 | + layer.mlp.down_proj.block_size = block_size |
| 329 | + |
| 330 | + return model |
0 commit comments