Commit aef02b1

server : disable cached prompts with self-extend
1 parent 61b6370

2 files changed: +19 −31 lines

2 files changed

+19
-31
lines changed

examples/server-embd.py
Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(i)*1024}
+        json= {"content": str(0)*1024}
     ) for i in range(n)])

     for response in responses:
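
The change above makes every concurrent request in the test script send the same 1024-character body (str(0)*1024) instead of a body that varies with the request index. For reference, a minimal standalone sketch of the updated call, using a plain synchronous requests.post in place of the script's requests_post_async helper; the endpoint and payload come from the diff, and the URL is the script's default:

import requests

model_url = "http://127.0.0.1:6900"

# str(0)*1024 produces the identical "000...0" payload for every request,
# whereas the previous str(i)*1024 varied the content per request index.
response = requests.post(f"{model_url}/embedding", json={"content": str(0) * 1024})
print(response.json())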

examples/server/server.cpp
Lines changed: 18 additions & 30 deletions

@@ -816,6 +816,11 @@ struct llama_server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        if (slot.params.cache_prompt && slot.ga_n != 1) {
+            LOG_WARNING("cache_prompt is not supported with group-attention", {});
+            slot.params.cache_prompt = false;
+        }
+
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
             LOG_WARNING("Max tokens to predict exceeds server configuration", {
@@ -1769,6 +1774,8 @@ struct llama_server_context {

                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
                 } else {
+                    GGML_ASSERT(slot.ga_n == 1);
+
                     // push the prompt into the sampling context (do not apply grammar)
                     for (auto & token : prompt_tokens) {
                         llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
@@ -1783,34 +1790,17 @@ struct llama_server_context {
                     }

                     slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
-
-                    if (slot.ga_n != 1) {
-                        int ga_i = 0;
-                        int32_t ga_n = slot.ga_n;
-                        int32_t ga_w = slot.ga_w;
-                        int32_t slot_npast = 0;
-                        for (int k = 0; k < slot.n_past; ++k) {
-                            while (slot_npast >= ga_i + ga_w) {
-                                const int bd = (ga_w/ga_n)*(ga_n - 1);
-                                slot_npast -= bd;
-                                ga_i += ga_w/ga_n;
-                            }
-                            slot_npast++;
-                        }
-                        slot.n_past_se = slot_npast;
-                        slot.ga_i = ga_i;
-                    }
-
-                    LOG_INFO("slot progression", {
-                        { "id_slot", slot.id },
-                        { "id_task", slot.id_task },
-                        { "n_past", slot.n_past },
-                        { "n_past_se", slot.n_past_se },
-                        { "ga_i", slot.ga_i },
-                        { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
-                    });
                 }

+                LOG_INFO("slot progression", {
+                    { "id_slot", slot.id },
+                    { "id_task", slot.id_task },
+                    { "n_past", slot.n_past },
+                    { "n_past_se", slot.n_past_se },
+                    { "ga_i", slot.ga_i },
+                    { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
+                });
+
                 slot.cache_tokens = prompt_tokens;

                 if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
@@ -1841,15 +1831,13 @@ struct llama_server_context {
                     {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                 });

-                std::vector<llama_token> prefix_tokens = prompt_tokens;
-
                 int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

                 int32_t ga_i = slot.ga_i;
                 int32_t ga_n = slot.ga_n;
                 int32_t ga_w = slot.ga_w;

-                for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) {
+                for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
                     if (slot.ga_n != 1) {
                         while (slot_npast >= ga_i + ga_w) {
                             const int bd = (ga_w/ga_n)*(ga_n - 1);
@@ -1858,7 +1846,7 @@ struct llama_server_context {
                         }
                     }

-                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                    llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);

                     slot_npast++;
                 }
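
The while-loop kept in the non-cached path above is the self-extend (group-attention) position bookkeeping: each time slot_npast passes the current window (ga_i + ga_w), it is pulled back by bd = (ga_w/ga_n)*(ga_n - 1), which compresses the range of KV-cache positions a long prompt occupies. A rough standalone illustration of that arithmetic, with hypothetical ga_n/ga_w values and no attempt to mirror the server's slot state handling:

# Sketch of the group-attention position adjustment from the loop above.
# ga_n (group factor) and ga_w (window width) are hypothetical example values.
ga_n, ga_w = 4, 512

def remap_positions(n_tokens: int) -> list[int]:
    positions = []
    ga_i = 0
    slot_npast = 0
    for _ in range(n_tokens):
        # Same adjustment as the C++ while-loop: once slot_npast passes the
        # current window (ga_i + ga_w), pull it back by bd.
        while slot_npast >= ga_i + ga_w:
            bd = (ga_w // ga_n) * (ga_n - 1)
            slot_npast -= bd
            ga_i += ga_w // ga_n
        positions.append(slot_npast)
        slot_npast += 1
    return positions

print(remap_positions(2048)[-1])  # compressed position assigned to the last of 2048 tokens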
