@@ -816,6 +816,11 @@ struct llama_server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
+        if (slot.params.cache_prompt && slot.ga_n != 1) {
+            LOG_WARNING("cache_prompt is not supported with group-attention", {});
+            slot.params.cache_prompt = false;
+        }
+
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
             LOG_WARNING("Max tokens to predict exceeds server configuration", {
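Note for context: the `slot.sparams.*` assignments above go through the server's `json_value` helper (defined in `examples/server/utils.hpp`). A minimal sketch of its behavior, with the body assumed for illustration: a request field is used when present and non-null, otherwise the compiled-in default applies, which is why the new guard can simply flip `slot.params.cache_prompt` back to `false` instead of rejecting the request.

```cpp
#include <nlohmann/json.hpp>
#include <string>

// Sketch (assumed body) of the json_value helper used above: fetch data[key]
// when present and non-null, otherwise fall back to the supplied default.
template <typename T>
static T json_value(const nlohmann::json & data, const std::string & key, const T & default_value) {
    if (data.contains(key) && !data.at(key).is_null()) {
        try {
            return data.at(key).get<T>();
        } catch (const nlohmann::json::exception &) {
            return default_value; // wrong type in the request: keep the default
        }
    }
    return default_value;
}
```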
@@ -1769,6 +1774,8 @@ struct llama_server_context {
 
                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
                 } else {
+                    GGML_ASSERT(slot.ga_n == 1);
+
                     // push the prompt into the sampling context (do not apply grammar)
                     for (auto & token : prompt_tokens) {
                         llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
@@ -1783,34 +1790,17 @@ struct llama_server_context {
                     }
 
                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
-
-                    if (slot.ga_n != 1) {
-                        int ga_i = 0;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-                        int32_t slot_npast = 0;
-                        for (int k = 0; k < slot.n_past; ++k) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                            slot_npast++;
-                        }
-                        slot.n_past_se = slot_npast;
-                        slot.ga_i = ga_i;
-                    }
-
-                    LOG_INFO("slot progression", {
-                        { "id_slot",                   slot.id },
-                        { "id_task",                   slot.id_task },
-                        { "n_past",                    slot.n_past },
-                        { "n_past_se",                 slot.n_past_se },
-                        { "ga_i",                      slot.ga_i },
-                        { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
-                    });
                 }
 
+                LOG_INFO("slot progression", {
+                    { "id_slot",                   slot.id },
+                    { "id_task",                   slot.id_task },
+                    { "n_past",                    slot.n_past },
+                    { "n_past_se",                 slot.n_past_se },
+                    { "ga_i",                      slot.ga_i },
+                    { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
+                });
+
                 slot.cache_tokens = prompt_tokens;
 
                 if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
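The block deleted in this hunk recomputed the self-extend position `n_past_se` and window start `ga_i` when a cached prompt was partially reused. Since `cache_prompt` is now forced off whenever `ga_n != 1` (see the first hunk, plus the `GGML_ASSERT` above), that path became dead code, and the "slot progression" log moves out of the branch so it is emitted on both paths. For intuition, here is a self-contained sketch of the remapping the deleted loop performed; the `ga_n`/`ga_w` values are illustrative only:

```cpp
#include <cstdint>
#include <cstdio>

// Standalone sketch of the group-attention (self-extend) position remapping:
// walk n_past absolute positions and fold every full window back by bd,
// exactly as the deleted block (and the decode loop below) does.
int main() {
    const int32_t ga_n   = 4;    // group-attention factor (illustrative)
    const int32_t ga_w   = 512;  // group-attention window  (illustrative)
    const int     n_past = 2048;

    int32_t ga_i       = 0;
    int32_t slot_npast = 0;
    for (int k = 0; k < n_past; ++k) {
        // once the remapped position leaves the current window, pull it
        // back by bd and advance the window start
        while (slot_npast >= ga_i + ga_w) {
            const int bd = (ga_w/ga_n)*(ga_n - 1);
            slot_npast -= bd;
            ga_i       += ga_w/ga_n;
        }
        slot_npast++;
    }

    // prints: n_past = 2048 -> n_past_se = 896, ga_i = 384
    printf("n_past = %d -> n_past_se = %d, ga_i = %d\n", n_past, slot_npast, ga_i);
    return 0;
}
```

With these numbers the 2048 absolute positions collapse to `ga_w + (n_past - ga_w)/ga_n = 896` KV positions: the first window is kept verbatim and everything past it is compressed by the group factor.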
@@ -1841,15 +1831,13 @@ struct llama_server_context {
                     {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                 });
 
-                std::vector<llama_token> prefix_tokens = prompt_tokens;
-
                 int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
 
                 int32_t ga_i = slot.ga_i;
                 int32_t ga_n = slot.ga_n;
                 int32_t ga_w = slot.ga_w;
 
-                for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) {
+                for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
                     if (slot.ga_n != 1) {
                         while (slot_npast >= ga_i + ga_w) {
                             const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1858,7 +1846,7 @@ struct llama_server_context {
                         }
                     }
 
-                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                    llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
 
                     slot_npast++;
                 }
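The remaining change is mechanical: `prefix_tokens` was an unconditional copy of `prompt_tokens`, so the loop now indexes `prompt_tokens` directly. For reference, `llama_batch_add` appends one token at an explicit position, which is what lets the loop feed `prompt_tokens[slot.n_past]` at the remapped position `system_tokens.size() + slot_npast`. A sketch of what the helper does, assuming the batch layout of the time (the real one lives in `common/common.cpp`):

```cpp
#include "llama.h"
#include <vector>

// Sketch of llama_batch_add as used above: append one token at position `pos`,
// tagged with the sequence ids that should attend to it; `logits` marks whether
// the model should produce output logits for this token.
static void llama_batch_add_sketch(llama_batch & batch, llama_token id, llama_pos pos,
                                   const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;
    batch.n_tokens++;
}
```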