Conversation
Signed-off-by: Karen Mosoyan <[email protected]>
Pull request overview
This PR adds support for Qwen3.5 models, which use a hybrid Transformer + DeltaNet (linear-attention) architecture. Unlike standard Qwen models, which use multi-head attention throughout, Qwen3.5 interleaves standard attention layers with DeltaNet linear-attention layers that maintain a fixed-size recurrent state instead of a KV cache.
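The recurrent-state idea behind the DeltaNet layers can be sketched as follows. This is a minimal NumPy illustration of a gated delta-rule decode step, not the kernel in this PR; the function name, gating scheme, and parameters are invented for illustration:

```python
import numpy as np

def deltanet_decode_step(S, q, k, v, g, beta):
    """One linear-attention decode step for a single head (sketch only).

    S    : (d_k, d_v) recurrent state matrix, replacing the KV cache
    q, k : (d_k,) query/key vectors;  v : (d_v,) value vector
    g    : scalar decay gate in (0, 1];  beta : scalar write strength
    """
    S = g * S                                 # decay the old state
    S = S + beta * np.outer(k, v - k @ S)     # delta-rule update toward v
    return S, q @ S                           # read-out is O(d_k * d_v) per token

# Usage: the state stays fixed-size no matter how long the sequence grows,
# which is why these layers need no per-token KV cache entries.
d_k, d_v = 4, 4
S = np.zeros((d_k, d_v))
rng = np.random.default_rng(0)
for _ in range(3):
    q, k, v = rng.standard_normal((3, d_k))
    S, o = deltanet_decode_step(S, q, k, v, g=0.9, beta=0.5)
```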
Changes:
- Adds a new `Qwen3p5Model` C++ class implementing the hybrid attention/DeltaNet forward pass, with two new graph operations (`GATED_DELTANET_DECODE`, `GATED_DELTANET_PREFILL`) and their compute kernels in `graph_ops_nn.cpp`.
- Extends the Python weight converter (`converter.py`, `weight_patterns.py`, `config_utils.py`) to extract Qwen3.5-specific linear-attention weights (including splitting the combined `in_proj_qkv` tensor) and MTP global weights.
- Updates `cli.py` to load Qwen3.5 checkpoints as raw tensors (bypassing `AutoModelForCausalLM`), and updates `engine_tokenizer.cpp` to emit a `<think>` generation prefix when the chat template contains a thinking tag.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `cactus/models/model_qwen3p5.cpp` | New model class: weight loading, DeltaNet/attention forward pass, cache management |
| `cactus/models/model.h` | New `Qwen3p5Model` class declaration with `WeightNodeIDs` and DeltaNet cache fields |
| `cactus/graph/graph_ops_nn.cpp` | New `compute_gated_deltanet_decode_node` and two prefill variants (chunked and sequential) |
| `cactus/graph/graph_builder.cpp` | New `gated_deltanet_decode` and `gated_deltanet_prefill` graph builder methods |
| `cactus/graph/graph.h` | New `OpType` enum values, `chunk_size` in `OpParams`, new method declarations |
| `cactus/graph/graph_execute.cpp` | Dispatch new op types to their compute functions |
| `cactus/engine/engine.h` | New `Config` fields for linear-attention topology |
| `cactus/engine/engine_model.cpp` | Parse new config fields; detect and instantiate `Qwen3p5Model` |
| `cactus/engine/engine_tokenizer.cpp` | Emit `<think>` prefix when chat template contains a thinking tag |
| `python/src/config_utils.py` | Extract `partial_rotary_factor`, `layer_types`, and linear-attention head config |
| `python/src/converter.py` | MTP weight mapping; split `in_proj_qkv` for Qwen3.5 |
| `python/src/weight_patterns.py` | New weight name patterns for Qwen3.5 linear-attention layers |
| `python/src/cli.py` | Load Qwen3.5 checkpoints as raw tensors, bypassing `AutoModelForCausalLM` |
```python
if linear_num_value_heads is not None:
    base['linear_num_value_heads'] = int(linear_num_value_heads)
if linear_value_head_dim is not None:
    base['linear_value_head_dim'] = int(linear_value_head_dim)
```
In extract_base_config, linear_q_proj_dim is derived from linear_num_key_heads * linear_key_head_dim (line 162), making it always equal to linear_k_proj_dim. For the Qwen3.5 DeltaNet architecture the query and key projections share the same head count and dimensions, so this is the correct formula. However, there is no separate variable for linear_num_query_heads or linear_query_head_dim — the assumption that Q and K have the same projection dimension is hardcoded. If any future Qwen3.5 variant uses different Q-head vs K-head counts (GQA-style for linear attention), this will produce incorrect splits of in_proj_qkv.weight in the converter, leading to wrong model weights at inference time. The assumption should at least be documented with a comment, and ideally the converter should cross-check the derived linear_q_proj_dim + linear_k_proj_dim + linear_v_proj_dim against the actual weight row count before splitting.
Suggested change:

```diff
  base['linear_value_head_dim'] = int(linear_value_head_dim)
+ # NOTE: For current Qwen3.5 DeltaNet architectures we assume that the query
+ # and key projections share the same number of heads and head dimension.
+ # As a result, we do not track separate `linear_num_query_heads` or
+ # `linear_query_head_dim` here, and we derive `linear_q_proj_dim` from the
+ # key-head parameters. This means `linear_q_proj_dim` will always equal
+ # `linear_k_proj_dim`.
+ #
+ # If a future variant decouples Q and K (e.g. different numbers of Q/K
+ # heads or head dims, such as a GQA-style linear attention), this logic
+ # must be updated to use explicit query-head settings. In that case, the
+ # converter that splits `in_proj_qkv.weight` should also verify that
+ # `linear_q_proj_dim + linear_k_proj_dim + linear_v_proj_dim` matches the
+ # actual weight shape before slicing.
```
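The cross-check the review asks for could look roughly like this in the converter. This is a hedged sketch with a hypothetical helper name; it assumes the combined tensor stacks Q, K, then V rows along dimension 0:

```python
import numpy as np

def split_in_proj_qkv(tensor, q_dim, k_dim, v_dim):
    """Split a combined [Q; K; V] projection after verifying dims (sketch only).

    Raises instead of silently producing misaligned slices when the dims
    derived from the config do not match the actual row count.
    """
    expected_rows = q_dim + k_dim + v_dim
    if tensor.shape[0] != expected_rows:
        raise ValueError(
            f"in_proj_qkv has {tensor.shape[0]} rows but config implies "
            f"{q_dim}+{k_dim}+{v_dim}={expected_rows}; refusing to split"
        )
    # Slice order follows the assumed [Q; K; V] layout.
    q = tensor[:q_dim, :]
    k = tensor[q_dim:q_dim + k_dim, :]
    v = tensor[q_dim + k_dim:, :]
    return q, k, v
```

With such a check, a future variant whose Q and K projections differ would fail loudly at conversion time instead of producing wrong weights at inference.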
```cpp
size_t deltanet_gate_bias = 0;
size_t deltanet_beta_bias = 0;
size_t deltanet_z_weight = 0;
size_t deltanet_conv_weight = 0;
```
In build_gated_deltanet, the deltanet_beta_bias field (populated from linear_attn_dt_bias.weights) is added to a_logits on the gate path, while deltanet_gate_bias (populated from linear_attn_A_log.weights) scales the gate. Both fields therefore hold gate parameters, and the names are misleading: deltanet_beta_bias does not hold a beta bias at all, and deltanet_gate_bias is only one of the two gate parameters. This makes the code hard to maintain and understand. The struct field names in LayerWeights should reflect the actual parameter semantics (e.g., deltanet_dt_bias and deltanet_a_log).
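For context, in public Gated DeltaNet formulations both parameters feed the per-head decay gate, roughly as below. This is a sketch of the common formulation; the exact formula in this PR's kernel may differ:

```python
import numpy as np

def softplus(x):
    # Numerically simple softplus; adequate for a sketch.
    return np.log1p(np.exp(x))

def decay_gate(a_logits, dt_bias, A_log):
    """Per-head decay gate g in (0, 1] (illustrative sketch).

    dt_bias shifts the gate input and A_log scales it, so both are gate
    parameters; names like deltanet_dt_bias / deltanet_a_log describe
    them far better than "beta_bias" / "gate_bias".
    """
    dt = softplus(a_logits + dt_bias)   # per-head time step, >= 0
    g = np.exp(-np.exp(A_log) * dt)     # decay factor in (0, 1]
    return g
```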
```cpp
const size_t inferred_v_heads = v_proj_dim / inferred_key_dim;
if (inferred_k_heads == 0 || inferred_v_heads == 0 || inferred_v_heads % inferred_k_heads != 0) {
    throw std::runtime_error("Qwen3p5Model failed to infer compatible linear-attention head topology");
}

deltanet_heads_ = inferred_v_heads;
deltanet_key_dim_ = inferred_key_dim;
deltanet_value_dim_ = inferred_key_dim;
```
In post_init, when inferring dimensions from weight shapes (the fallback path at line 108), deltanet_value_dim_ is set to inferred_key_dim — the same value as deltanet_key_dim_. This is correct only when V_dim == K_dim for the linear attention layer. If a model has different key and value dimensions (e.g., linear_value_head_dim != linear_key_head_dim), this fallback will produce an incorrect value dimension, leading to a mismatched state tensor shape in build_gated_deltanet and ultimately a runtime error or silent wrong output. The fallback should either derive value_dim from the v_weight buffer shape or document clearly that this path requires V_dim == K_dim.
Suggested change:

```diff
- const size_t inferred_v_heads = v_proj_dim / inferred_key_dim;
- if (inferred_k_heads == 0 || inferred_v_heads == 0 || inferred_v_heads % inferred_k_heads != 0) {
-     throw std::runtime_error("Qwen3p5Model failed to infer compatible linear-attention head topology");
- }
- deltanet_heads_ = inferred_v_heads;
- deltanet_key_dim_ = inferred_key_dim;
- deltanet_value_dim_ = inferred_key_dim;
+ if (inferred_k_heads == 0 || v_proj_dim % inferred_k_heads != 0) {
+     throw std::runtime_error("Qwen3p5Model failed to infer compatible linear-attention head/value dims");
+ }
+ const size_t inferred_value_dim = v_proj_dim / inferred_k_heads;
+ if (inferred_value_dim == 0) {
+     throw std::runtime_error("Qwen3p5Model inferred zero linear-attention value dim");
+ }
+ deltanet_heads_ = inferred_k_heads;
+ deltanet_key_dim_ = inferred_key_dim;
+ deltanet_value_dim_ = inferred_value_dim;
```
```cpp
}
gb->set_input(input_node_id, input_data.data(), Precision::FP32);

static std::set<uint32_t> skip_layers = {};
```
The skip_layers variable is declared static, so it is initialized once and shared across all invocations of forward(). This pattern appears in other models (e.g., model_qwen.cpp) as a debug hook for skipping layers, and since the set is always empty here it is not a functional bug. It is, however, confusing and a maintenance hazard: if someone mutates the set in layer-skipping debug code, the static lifetime means those entries persist across calls, which is easy to overlook. The pattern is consistent with the existing codebase convention.
Suggested change:

```diff
- static std::set<uint32_t> skip_layers = {};
+ std::set<uint32_t> skip_layers;
```
```python
if q_dim > 0 and k_dim > 0 and v_dim > 0 and row_dim == (q_dim + k_dim + v_dim):
    q_weight = tensor[:q_dim, :]
    k_weight = tensor[q_dim:q_dim + k_dim, :]
    v_weight = tensor[q_dim + k_dim:, :]

    save_tensor_with_header(
        q_weight, output_dir / f'layer_{i}_linear_attn_q.weights', tensor_precision, transpose=False,
        stats_tracker=quantization_stats, args=args, model_type=detected_model_type
    )
    save_tensor_with_header(
        k_weight, output_dir / f'layer_{i}_linear_attn_k.weights', tensor_precision, transpose=False,
        stats_tracker=quantization_stats, args=args, model_type=detected_model_type
    )
    save_tensor_with_header(
        v_weight, output_dir / f'layer_{i}_linear_attn_v.weights', tensor_precision, transpose=False,
        stats_tracker=quantization_stats, args=args, model_type=detected_model_type
    )

    saved_tensor_full_names.add(full_name)
    found = True
    break
```
When model_type_str.startswith('qwen3_5') is true and the in_proj_qkv tensor is found but any of q_dim, k_dim, v_dim is 0 (the model config doesn't expose linear_num_key_heads/linear_key_head_dim), the whole in_proj_qkv is saved as a combined weight file, but the split Q/K/V files (linear_attn_q.weights, linear_attn_k.weights, linear_attn_v.weights) are silently not written. The C++ load_weights_to_graph will then try mmap_required for those split files and throw a runtime error. There should be a warning or error when the split cannot be produced, rather than letting the runtime fail during model loading.
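A conversion-time guard along the lines this comment asks for might look like the following. The helper name is hypothetical; the parameter names mirror the q_dim/k_dim/v_dim variables in the snippet above:

```python
def check_qkv_split(q_dim, k_dim, v_dim, row_dim, full_name):
    """Fail loudly when in_proj_qkv cannot be split into Q/K/V files (sketch).

    The C++ loader expects separate linear_attn_q/k/v weight files, so
    surfacing the problem during conversion beats a runtime mmap failure
    at model-load time.
    """
    if q_dim > 0 and k_dim > 0 and v_dim > 0 and row_dim == q_dim + k_dim + v_dim:
        return True
    raise ValueError(
        f"cannot split {full_name}: config gives q={q_dim}, k={k_dim}, v={v_dim} "
        f"but the tensor has {row_dim} rows; the C++ loader requires separate "
        f"linear_attn_q/k/v weight files"
    )
```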
cactus/engine/engine_tokenizer.cpp (Outdated)
```cpp
if (!tools_json.empty()) {
const bool template_has_think = !chat_template_.empty() && chat_template_.find("<think>") != std::string::npos;
if (template_has_think) {
    result += "<|im_start|>assistant\n<think>\n";
```
The format_qwen_style function now emits <think>\n (with a newline) as the generation prompt prefix when template_has_think is true. However, if both template_has_think is true AND tools_json is non-empty, the tools path (<think>\n</think>\n\n) is skipped. This means tool-call generation for a Qwen3-Thinking model will start with an open <think> tag and never close it before emitting tool-call JSON, which may break the tool-call parser. If template_has_think should take priority even when tools are present, the previous <think>\n</think>\n\n handling for tools should also be accounted for in this branch.
Suggested change:

```diff
- result += "<|im_start|>assistant\n<think>\n";
+ if (!tools_json.empty()) {
+     // When tools are present, close the <think> tag immediately to keep tool-call JSON outside of it.
+     result += "<|im_start|>assistant\n<think>\n</think>\n\n";
+ } else {
+     // No tools: leave <think> open for the model's chain-of-thought.
+     result += "<|im_start|>assistant\n<think>\n";
+ }
```
Signed-off-by: HenryNdubuaku <[email protected]>