Skip to content

Conversation

@pwilkin
Copy link
Collaborator

@pwilkin pwilkin commented Sep 18, 2025

EDIT: README FIRST
This is an implementation of a new type of attention gating in GGML.
Therefore, this implementation will be focused on CORRECTNESS ONLY.
Speed tuning and support for more architectures will come in future PRs.
Please do not spam this threads with reports about performance, especially on backend architectures (CUDA, Vulkan).

CURRENT STATE: core is done

===
It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Resolves #15940

@github-actions github-actions bot added python python script changes ggml changes relating to the ggml tensor library for machine learning labels Sep 18, 2025
@gabe-l-hart
Copy link
Collaborator

I'll try to get into it in more detail soon, but here are a few general thoughts after quickly skimming the PR:

  1. The structure of what you've got smells correct, so it's likely close, but missing something small yet critical
  2. A full repro with the error it's raising would definitely help debug
  3. My debugging process for this would be:
    1. Make sure tokenization is solid (print statements as necessary to compare tokens before input)
    2. Use llama-eval-callback to dump tensors for a single prefill step
    3. Run an identical single prefill with the reference impl (transformers or otherwise), and inject prints as needed to dump tensors along the way
    4. Visually comb through them (particularly the sum at each point) to see where things start diverging significantly

@bugparty
Copy link
Contributor

It's been a real learning experience, not gonna lie, but if someone with hybrid model implementation experience (@gabe-l-hart ?) has some quick tips, I'd be grateful.

Currently at the stage of "graph builds, but first decode complains about wrong memory model", probably not building the inputs correctly.

Resolves #15940

interesting, maybe we can learn together

@pwilkin pwilkin marked this pull request as draft September 19, 2025 08:07
@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

#0  __syscall_cancel_arch () at ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S:56
56      in ../sysdeps/unix/sysv/linux/x86_64/syscall_cancel.S
#1  0x000070552b29eb63 in __internal_syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=0, a6=0, nr=61) at ./nptl/cancellation.c:49
warning: 49     ./nptl/cancellation.c: No such file or directory
#2  __syscall_cancel (a1=<optimized out>, a2=<optimized out>, a3=<optimized out>, a4=<optimized out>, a5=a5@entry=0, a6=a6@entry=0, nr=61) at ./nptl/cancellation.c:75
75      in ./nptl/cancellation.c
#3  0x000070552b31afdf in __GI___wait4 (pid=<optimized out>, stat_loc=<optimized out>, options=<optimized out>, usage=<optimized out>) at ../sysdeps/unix/sysv/linux/wait4.c:30
warning: 30     ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory
#4  0x000070552bb45c31 in ggml_print_backtrace () at /devel/tools/llama.cpp/ggml/src/ggml.c:196
warning: Source file is more recent than executable.
196             waitpid(child_pid, NULL, 0);
#5  0x000070552bb45de5 in ggml_abort (file=0x70552bbcdac8 "/devel/tools/llama.cpp/ggml/src/ggml-backend.cpp", line=189, fmt=0x70552bbcd8af "GGML_ASSERT(%s) failed") at /devel/tools/llama.cpp/ggml/src/ggml.c:230
230             ggml_print_backtrace();
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
#8  0x000070552c07a114 in llm_graph_input_rs::set_input (this=0x5f11bdf6aea0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:241
241             GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);
#10 0x000070552c07b549 in llm_graph_result::set_inputs (this=0x5f11be01ddf0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:480
480             input->set_input(ubatch);
#11 0x000070552c01ddb3 in llama_context::process_ubatch (this=0x5f11c05b5b50, ubatch=..., gtype=LLM_GRAPH_TYPE_DECODER, mctx=0x5f11be00ff00, ret=@0x7fff74d22ea4: 538976288) at /devel/tools/llama.cpp/src/llama-context.cpp:779
779             res->set_inputs(&ubatch);
#12 0x000070552c01f367 in llama_context::decode (this=0x5f11c05b5b50, batch_inp=...) at /devel/tools/llama.cpp/src/llama-context.cpp:1088
1088            const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
#13 0x000070552c025e49 in llama_decode (ctx=0x5f11c05b5b50, batch=...) at /devel/tools/llama.cpp/src/llama-context.cpp:2726
2726        const int ret = ctx->decode(batch);
#14 0x00005f11a2021559 in common_init_from_params (params=...) at /devel/tools/llama.cpp/common/common.cpp:1066
1066                llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
#15 0x00005f11a1e4a3c0 in main (argc=7, argv=0x7fff74d25968) at /devel/tools/llama.cpp/tools/main/main.cpp:140
140         common_init_result llama_init = common_init_from_params(params);

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

@CISC
Copy link
Collaborator

CISC commented Sep 19, 2025

  1. A full repro with the error it's raising would definitely help debug

Running llama-cli -m reference/qwen3_next_500m/Qwen3_Next_500M-8x417M-BF16.gguf -ngl 999 -p "Who are " yields this weird memory error:

...
#6  0x000070552bb6091e in ggml_backend_buffer_get_type (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:189
189         GGML_ASSERT(buffer);
#7  0x000070552bb6080e in ggml_backend_buffer_is_host (buffer=0x0) at /devel/tools/llama.cpp/ggml/src/ggml-backend.cpp:170
170         return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
...

The backend buffer is NULL.

@ngxson
Copy link
Collaborator

ngxson commented Sep 19, 2025

#9  0x000070552c07b03c in llm_graph_input_mem_hybrid::set_input (this=0x5f11bdf6aee0, ubatch=0x5f11be011300) at /devel/tools/llama.cpp/src/llama-graph.cpp:437
437         inp_rs->set_input(ubatch);

The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.

I'll try to merge the op into the ggml_delta_net function call as @ngxson suggested.

Hmm I think I said the reverse: not to merge it but make the op simple

I feel like this op can be implemented using other ggml ops like mul, mul_mat, sum. Which part of the calculation do you think that can't be constructed using existing ops?

This is the more important question: should we try to implement it using existing ops, or add a new op and spend even more time to optimize it cross all backends?

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

Now this is an error I haven't expected to encounter:

GGML_ABORT("not enough space in the context's memory pool");

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

The model doesn't seem to have any recurrence layers. This makes the set input fails due to input node not being present in cgraph.

How do I allocate the memory for the linear layers then? I seem to have misunderstood how build_inp_mem_hybrid() works...

@yarikdevcom
Copy link

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 19, 2025

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

@ServeurpersoCom
Copy link
Collaborator

ServeurpersoCom commented Sep 19, 2025

@pwilkin any chance to buy you a coffee?(Paterson etc.) so community able to donate for your efforts. Thank you!

Added a buymeacoffee link to my profile (do consider first funding the Llama.cpp project itself, though!)

I send a coffee also.

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

GGML_ABORT("not enough space in the context's memory pool");

Probably there are too many nodes on cgraph, try increasing the limit via llama_context::graph_max_nodes()

Comment on lines 19054 to 19056
Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these ggml_cont can be removed if Q/gate are separated. ggml_cont is not recommended when dealing with big tensors

Copy link
Collaborator

@CISC CISC Sep 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually none of these need ggml_cont, Q is 3D already, Q/K are RoPEd so can be views and V can also be a 3D view now.

Edit: sorry, not quite true about V, only if QKV is fused, the weird gate fuse threw me off. Nevertheless, K/V are already contiguous at this point.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Copy link
Collaborator

@CISC CISC Sep 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

Are you sure? AFAIK those issues are fixed.

Edit: Also, if there still are issues they will never get fixed if we work around them. :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is that Q is non-contiguous and ggml_rope(_ext) does not work very well with non-cont tensors, it's still buggy on certain backends

I think all of these cases are fixed now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an impl of 2D rope that relies on ggml_view: https://github.com/ngxson/ggml-easy/blob/f56e5e499b1f21a4aae73010e9d9582840428457/demo/2d-rope.cpp

It works on CPU and Metal, but doesn't work on CUDA/Vulkan. Couldn't tested on other backends, but feel free to make a PR to address this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that seems to work. sorry @pwilkin you will need to manually revert the change where I split Q/gate. the tensor shape for Q will be:

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);

layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shape of LLM_TENSOR_ATTN_Q and LLM_TENSOR_SSM_OUT should not contain n_ff

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

^ proposed fix for the 3 comments above: 46110e0

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

@ngxson Thanks, scale_bias was one op I was missing in my endeavors :>

I got an LLM to rewrite the internal delta into tensor logic. After a day of manually fixing that crap, I think I understand it enough to rewrite it myself ;)

@ngxson
Copy link
Collaborator

ngxson commented Sep 20, 2025

Honestly I would prefer taking time to understand the mamba/ssm implementation then writing the code manually. Code written by LLM are mostly attempts for 1-to-1 translation from pytorch --> GGML which looks quite confusing

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

Honestly I would prefer taking time to understand the mamba/ssm implementation then writing the code manually. Code written by LLM are mostly attempts for 1-to-1 translation from pytorch --> GGML which looks quite confusing

Yeah, for me getting a rough outline then going over it manually is the best way to learn :)

I tried the "one-to-one" approach and ended up with a graph that wouldn't fit in 16 GB of RAM for a 500M model...

@pwilkin
Copy link
Collaborator Author

pwilkin commented Sep 20, 2025

Aight, I cleaned up the main graph calculation, now I have to figure out how to include conv_states_all in my delta_net function in order to not get the memory error.

@bartowski1182
Copy link
Contributor

I'm working on imatrix quants, takes a hot minute but I can also make it public before it's fully ready

@pwilkin
Copy link
Collaborator Author

pwilkin commented Nov 28, 2025

@LingyeZ

Can you create a quantization for iq4_xs? My memory is not sufficient to allow me to implement bf16→iq4_xs quantization

ilintar/Qwen3-Next-80B-A3B-Instruct-GGUF - did them a long time ago :)

@lovedheart
Copy link
Contributor

If someone has interested in low-bit model, may try https://huggingface.co/lovedheart/Qwen3-Next-80B-A3B-Instruct-GGUF. I normally tried at least one real-world problem for testing with a complete llama-cli log attached. Any feedback appreciated.

@Mithras
Copy link

Mithras commented Nov 28, 2025

I'm able to run qwen3next on my 3090+5090, thank you so much!
I've seen claims that it's supposed to be 10x faster than qwen3 32b which is clearly not the case. I only get around 50 tps at best. I can probably run qwen3 32b faster than that...
What else is missing to make it fast on cuda? Any existing PRs I can try?

@M3l-Idk
Copy link

M3l-Idk commented Nov 28, 2025

I just downloaded a source code of llama.cpp and built it, but it seems like i still can not load the model. (Unable to load the model error).;
cmake -B build -DGGML_VULKAN=1 ⌂ 22:19
cmake --build build --config Release

arch linux, unsloth's quantisations

@bartowski1182
Copy link
Contributor

I just downloaded a source code of llama.cpp and built it, but it seems like i still can not load the model. (Unable to load the model error).; cmake -B build -DGGML_VULKAN=1 ⌂ 22:19 cmake --build build --config Release

arch linux, unsloth's quantisations

Just tried with the model I made with vulkan build and it had no error on Fedora with my AMD 395

Can give it a shot to double check:

https://huggingface.co/bartowski/Qwen_Qwen3-Next-80B-A3B-Thinking-GGUF

@bartowski1182
Copy link
Contributor

bartowski1182 commented Nov 28, 2025

@M3l-Idk for the record though I tried with unsloth's just now and also got no error, so may be on your end, you got any more info?

@M3l-Idk
Copy link

M3l-Idk commented Nov 28, 2025

@M3l-Idk for the record though I tried with unsloth's just now and also got no error, so may be on your end, you got any more info?

Not really, i dont have time today to debug it. Just wanted to know if its just me or everyone has this problem. Thank you anyway.

@mrfrosty009
Copy link

@bartowski1182 Im not sure if the right place to ask, but getting "llama_model_load: error loading model: error loading model architecture: unknown model architecture: 'qwen3next'" error with unsloth

D:\llama.cpp\llama.cpp>llama-cli -m "D:\llama.cpp\llama.cpp\models\Qwen3-Next-80B-A3B-Instruct-Q2_K.gguf" --n-gpu-layers 30
load_backend: loaded RPC backend from C:\Users\Mr.Frosty\AppData\Local\Microsoft\WinGet\Packages\ggml.llamacpp_Microsoft.Winget.Source_8wekyb3d8bbwe\ggml-rpc.dll
ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4080 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
ggml_vulkan: 1 = Intel(R) Graphics (Intel Corporation) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none
load_backend: loaded Vulkan backend from C:\Users\Mr.Frosty\AppData\Local\Microsoft\WinGet\Packages\ggml.llamacpp_Microsoft.Winget.Source_8wekyb3d8bbwe\ggml-vulkan.dll
load_backend: loaded CPU backend from C:\Users\Mr.Frosty\AppData\Local\Microsoft\WinGet\Packages\ggml.llamacpp_Microsoft.Winget.Source_8wekyb3d8bbwe\ggml-cpu-alderlake.dll
build: 7170 (e509411) with clang version 19.1.5 for x86_64-pc-windows-msvc
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device Vulkan0 (NVIDIA GeForce RTX 4080) (0000:01:00.0) - 13805 MiB free
llama_model_loader: loaded meta data with 49 key-value pairs and 807 tensors from D:\llama.cpp\llama.cpp\models\Qwen3-Next-80B-A3B-Instruct-Q2_K.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv 0: general.architecture str = qwen3next
llama_model_loader: - kv 1: general.type str = model
llama_model_loader: - kv 2: general.sampling.top_k i32 = 20
llama_model_loader: - kv 3: general.sampling.top_p f32 = 0.800000
llama_model_loader: - kv 4: general.sampling.temp f32 = 0.700000
llama_model_loader: - kv 5: general.name str = Qwen3-Next-80B-A3B-Instruct
llama_model_loader: - kv 6: general.finetune str = Instruct
llama_model_loader: - kv 7: general.basename str = Qwen3-Next-80B-A3B-Instruct
llama_model_loader: - kv 8: general.quantized_by str = Unsloth
llama_model_loader: - kv 9: general.size_label str = 80B-A3B
llama_model_loader: - kv 10: general.license str = apache-2.0
llama_model_loader: - kv 11: general.license.link str = https://huggingface.co/Qwen/Qwen3-Nex...
llama_model_loader: - kv 12: general.repo_url str = https://huggingface.co/unsloth
llama_model_loader: - kv 13: general.base_model.count u32 = 1
llama_model_loader: - kv 14: general.base_model.0.name str = Qwen3 Next 80B A3B Instruct
llama_model_loader: - kv 15: general.base_model.0.organization str = Qwen
llama_model_loader: - kv 16: general.base_model.0.repo_url str = https://huggingface.co/Qwen/Qwen3-Nex...
llama_model_loader: - kv 17: general.tags arr[str,2] = ["unsloth", "text-generation"]
llama_model_loader: - kv 18: qwen3next.block_count u32 = 48
llama_model_loader: - kv 19: qwen3next.context_length u32 = 262144
llama_model_loader: - kv 20: qwen3next.embedding_length u32 = 2048
llama_model_loader: - kv 21: qwen3next.feed_forward_length u32 = 5120
llama_model_loader: - kv 22: qwen3next.attention.head_count u32 = 16
llama_model_loader: - kv 23: qwen3next.attention.head_count_kv u32 = 2
llama_model_loader: - kv 24: qwen3next.rope.freq_base f32 = 10000000.000000
llama_model_loader: - kv 25: qwen3next.attention.layer_norm_rms_epsilon f32 = 0.000001
llama_model_loader: - kv 26: qwen3next.expert_used_count u32 = 10
llama_model_loader: - kv 27: qwen3next.attention.key_length u32 = 256
llama_model_loader: - kv 28: qwen3next.attention.value_length u32 = 256
llama_model_loader: - kv 29: qwen3next.expert_count u32 = 512
llama_model_loader: - kv 30: qwen3next.expert_feed_forward_length u32 = 512
llama_model_loader: - kv 31: qwen3next.expert_shared_feed_forward_length u32 = 512
llama_model_loader: - kv 32: qwen3next.ssm.conv_kernel u32 = 4
llama_model_loader: - kv 33: qwen3next.ssm.state_size u32 = 128
llama_model_loader: - kv 34: qwen3next.ssm.group_count u32 = 16
llama_model_loader: - kv 35: qwen3next.ssm.time_step_rank u32 = 32
llama_model_loader: - kv 36: qwen3next.ssm.inner_size u32 = 4096
llama_model_loader: - kv 37: qwen3next.rope.dimension_count u32 = 64
llama_model_loader: - kv 38: tokenizer.ggml.model str = gpt2
llama_model_loader: - kv 39: tokenizer.ggml.pre str = qwen2
llama_model_loader: - kv 40: tokenizer.ggml.tokens arr[str,151936] = ["!", """, "#", "$", "%", "&", "'", ...
llama_model_loader: - kv 41: tokenizer.ggml.token_type arr[i32,151936] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv 42: tokenizer.ggml.merges arr[str,151387] = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "Ġ t",...
llama_model_loader: - kv 43: tokenizer.ggml.eos_token_id u32 = 151645
llama_model_loader: - kv 44: tokenizer.ggml.padding_token_id u32 = 151654
llama_model_loader: - kv 45: tokenizer.ggml.add_bos_token bool = false
llama_model_loader: - kv 46: tokenizer.chat_template str = {%- if tools %}\n {{- '<|im_start|>...
llama_model_loader: - kv 47: general.quantization_version u32 = 2
llama_model_loader: - kv 48: general.file_type u32 = 10
llama_model_loader: - type f32: 313 tensors
llama_model_loader: - type q2_K: 229 tensors
llama_model_loader: - type q3_K: 60 tensors
llama_model_loader: - type q4_K: 108 tensors
llama_model_loader: - type q6_K: 49 tensors
llama_model_loader: - type bf16: 48 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type = Q2_K - Medium
print_info: file size = 27.16 GiB (2.93 BPW)
llama_model_load: error loading model: error loading model architecture: unknown model architecture: 'qwen3next'
llama_model_load_from_file_impl: failed to load model
common_init_from_params: failed to load model 'D:\llama.cpp\llama.cpp\models\Qwen3-Next-80B-A3B-Instruct-Q2_K.gguf', try reducing --n-gpu-layers if you're running out of VRAM
main: error: unable to load model

@danielhanchen
Copy link
Contributor

danielhanchen commented Nov 29, 2025

@mrfrosty009 You need to rebuild llama.cpp from source! See https://docs.unsloth.ai/models/qwen3-next#llama.cpp-run-qwen3-next-80b-a3b-instruct-tutorial for llama.cpp rebuilding for Unsloth related quants

Also the official llama.cpp build docs at https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md is recommended!

@ddh0
Copy link
Contributor

ddh0 commented Nov 29, 2025

@mrfrosty009 You need to rebuild llama.cpp from source! See https://docs.unsloth.ai/models/qwen3-next#llama.cpp-run-qwen3-next-80b-a3b-instruct-tutorial for llama.cpp rebuilding

Or, y'know, they could use the actual build instructions from this repo instead of your third-party guide :)

https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md

@danielhanchen
Copy link
Contributor

@ddh0 Oh apologies I linked it because they're specifically using the Unsloth quant - but yes using the official llama.cpp build instructions is recommend! I edited my comment as well

@mrfrosty009
Copy link

@ddh0 @danielhanchen Thanks! Just couple of min ago I noticed that it was linking to old llama folder on another disk, not the one I wanted. After typing right folder, it worked, thanks!

@M3l-Idk
Copy link

M3l-Idk commented Nov 29, 2025

@pwilkin i just wanted to say that you did a great job in this pr

@juanml82
Copy link

Guys, I genuinely appreciate all the valuable time you've put into this, but for whatever reasons, there is some problem with this implementation. I'm comparing using fastllm, in both instances with an instruct Q4 quant (fastllm uses this one, which doesn't use gguf https://huggingface.co/fastllm/Qwen3-Next-80B-A3B-Instruct-UD-Q4_K_M/tree/main ), same temperature, top_k etc (--temp 0.6 --top-k 20 --top-p 0.95 but there is minium p that I can see in fastllm). When prompted "You must reason step by step, but do NOT output the steps — only the final answer.

A dragon has 3 boxes:

  • Box A always contains a truth.
  • Box B always contains a lie.
  • Box C may contain either.

The dragon makes three statements:

  1. “Box A and Box B contain the same type of item.”
  2. “If Box C contains a truth, then Box A contains a lie.”
  3. “Box C contains a lie.”

Based only on these statements, determine with certainty which boxes contain truth or lie. If it cannot be fully determined, output: “Undetermined”.

Final answer only." Both implementation expose what they are thinking (they are not the thinking models), but the llama.cpp implementation reaches the conclusion "Undetermined" in 6580 words while the fastllm implementation accomplishes the task in 2235 words (I don't think fastllm exposes the token count). This does not happen on shorter prompts (e.g., joke explanation), so it appears specifically in longer reasoning chains. Not saying it’s “dumb,” but the verbosity difference suggests something may be off internally.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Dec 1, 2025

@juanml82 Given that (a) you're doing non-greedy decoding (b) you have not fixed a seed (c) both answers are correct it's much more likely that the difference you're reporting is based on the random seed than on anything in the implementation details.

@M3l-Idk
Copy link

M3l-Idk commented Dec 1, 2025

there is one issue that im having tho, no matter which unsloth's quantization i use, after prompting "create a flappy bird game in html" model gets stuck while trying to code some path to a non existing png file, it just starts hallucinating random numbers as the image's name.

@M3l-Idk
Copy link

M3l-Idk commented Dec 1, 2025

and the code is always the same if i remember correctly

@IIIIIllllIIIIIlllll
Copy link

IIIIIllllIIIIIlllll commented Dec 2, 2025

there is one issue that im having tho, no matter which unsloth's quantization i use, after prompting "create a flappy bird game in html" model gets stuck while trying to code some path to a non existing png file, it just starts hallucinating random numbers as the image's name.

try https://chat.qwen.ai/, the same issue:

image image

The model entered an infinite loop and eventually stopped outputting.

@jdvpro
Copy link

jdvpro commented Dec 2, 2025

CPU-only benchmark: Qwen3-Next vs Qwen3 MoE performance

Hardware: AMD EPYC 9454P (48c/96t, Zen 4), DDR5-4800 12-channel, 377GB RAM

model size params backend threads pp512 tg128
qwen3moe 30B.A3B Q4_0 16.11 GiB 30.53 B CPU 24 236.31 t/s 63.10 t/s
gpt-oss-120B Q4_0 60.87 GiB 116.83 B CPU 24 108.56 t/s 35.75 t/s
qwen3next 80B.A3B Q4_0 41.98 GiB 79.67 B CPU 24 73.74 t/s 11.76 t/s

Qwen3-Next shows ~5x slower tg than Qwen3-30B-A3B despite both having 3B active parameters. Even gpt-oss-120B with 5B active params runs 3x faster on tg.

Thanks for all the effort on this either way!

@pwilkin
Copy link
Collaborator Author

pwilkin commented Dec 2, 2025

@jdvpro oh, we'll get to optimizing it on CPU as well, don't worry.

@LynxPDA
Copy link

LynxPDA commented Dec 3, 2025

CPU-only benchmark: Qwen3-Next vs Qwen3 MoE performance

Hardware: AMD EPYC 9454P (48c/96t, Zen 4), DDR5-4800 12-channel, 377GB RAM
model size params backend threads pp512 tg128
qwen3moe 30B.A3B Q4_0 16.11 GiB 30.53 B CPU 24 236.31 t/s 63.10 t/s
gpt-oss-120B Q4_0 60.87 GiB 116.83 B CPU 24 108.56 t/s 35.75 t/s
qwen3next 80B.A3B Q4_0 41.98 GiB 79.67 B CPU 24 73.74 t/s 11.76 t/s

Qwen3-Next shows ~5x slower tg than Qwen3-30B-A3B despite both having 3B active parameters. Even gpt-oss-120B with 5B active params runs 3x faster on tg.

Thanks for all the effort on this either way!

It seems that pp performance on the CPU should be at least 3 times higher.

On the Ryzen AI Max+ 395, I get these results with build: ab6726e (7227)

CPU (Vulkan -ngl 0)

./llama-bench -m /home/lynx/llama-swap/models/Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf -fa 1 -r 1 -ngl 0
load_backend: loaded RPC backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-rpc.so
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon 8060S (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-vulkan.so
load_backend: loaded CPU backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-cpu-icelake.so

model size params backend ngl fa test t/s
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 0 1 pp512 134.41 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 0 1 tg128 14.65 ± 0.00

build: ab6726e (7227)

CPU (pure CPU backend)

./llama-bench -m /home/lynx/llama-swap/models/Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf -fa 1
load_backend: loaded RPC backend from /home/lynx/tmp/tmp/libggml-rpc.so
load_backend: loaded CPU backend from /home/lynx/tmp/tmp/libggml-cpu-icelake.so

model size params backend threads fa test t/s
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B CPU 16 1 pp512 101.56 ± 1.59
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B CPU 16 1 tg128 17.32 ± 0.08

build: ab6726e (7227)

Vulkan

./llama-bench -m /home/lynx/llama-swap/models/Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf -fa 1 -r 1 -ngl 999
load_backend: loaded RPC backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-rpc.so
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon 8060S (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-vulkan.so
load_backend: loaded CPU backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-cpu-icelake.so

model size params backend ngl fa test t/s
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 313.90 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 36.07 ± 0.00

build: ab6726e (7227)

Update: Added more relevant values ​​with pure CPU backend

@LynxPDA
Copy link

LynxPDA commented Dec 3, 2025

It's also interesting to note that, thanks to its architecture, Qwen3 Next's performance and tg speeds don't drop as much as Qwen3 30b (Vulkan backend) as the context increases.

And starting with a context of 32768, Qwen3 30b doesn't show any significant speed advantages.

Vulkan

./llama-bench -m /home/lynx/llama-swap/models/Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf,/home/lynx/llama-swap/models/Qwen3-Coder-30B-A3B-Instruct-UD-Q4_K_XL.gguf -fa 1 -r 1 -ngl 999 -d 512,2048,4096,8192,16384,32768,65536
load_backend: loaded RPC backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-rpc.so
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon 8060S (RADV GFX1151) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-vulkan.so
load_backend: loaded CPU backend from /home/lynx/llama-swap/llama.cpp/vulkan/libggml-cpu-icelake.so

model size params backend ngl fa test t/s
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d512 307.01 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d512 36.25 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d2048 301.11 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d2048 35.71 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d4096 276.94 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d4096 35.68 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d8192 252.42 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d8192 34.28 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d16384 187.12 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d16384 33.31 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d32768 140.82 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d32768 30.78 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 pp512 @ d65536 95.80 ± 0.00
qwen3next ?B Q4_K - Medium 45.10 GiB 79.67 B Vulkan 999 1 tg128 @ d65536 26.52 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d512 822.34 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d512 86.51 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d2048 684.52 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d2048 80.05 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d4096 545.75 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d4096 72.61 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d8192 384.72 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d8192 62.98 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d16384 183.72 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d16384 50.16 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d32768 94.64 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d32768 35.33 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 pp512 @ d65536 46.86 ± 0.00
qwen3moe 30B.A3B Q4_K - Medium 16.45 GiB 30.53 B Vulkan 999 1 tg128 @ d65536 21.86 ± 0.00

build: ab6726e (7227)

pp tg

However, if Qwen3 Next's performance and tg speeds can be further increased over time through optimizations, that would be incredible.

Thank you again to everyone for the opportunity to run Qwen3 Next locally!

P.S. Sorry for the spam.

Updated: Added graphs for clarity

@M3l-Idk
Copy link

M3l-Idk commented Dec 3, 2025

I dont know if this model is supposed to do that, but it seems like the more text it generates the faster it gets over time, like it can even get twice as fast over time than in the beggining.

@theo77186
Copy link
Contributor

I dont know if this model is supposed to do that, but it seems like the more text it generates the faster it gets over time, like it can even get twice as fast over time than in the beggining.

This could happen in the first few runs if the model isn't entirely loaded on GPU as llama.cpp loads the model into RAM. Otherwise it shouldn't happen at all. Text generation at long context get much less slowdown compared to standard attention models, though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

examples ggml changes relating to the ggml tensor library for machine learning model Model specific Nvidia GPU Issues specific to Nvidia GPUs python python script changes testing Everything test related

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature Request: Qwen3-Next support