From f2f08f8466dda7d69c96921b28ad3f004593600b Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 6 Dec 2025 00:50:08 +0100
Subject: [PATCH 1/7] server: improve speed of speculative decoding

---
 tools/server/server-context.cpp | 164 ++++++++++++++++++--------------
 1 file changed, 92 insertions(+), 72 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index f3f2edc0cc4..9277496eb3c 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -93,6 +93,8 @@ struct server_slot {
     int32_t n_remaining = -1;
     int32_t i_batch = -1;

+    std::vector<int32_t> i_batch_dft; // idx of draft tokens in the main batch
+
     int32_t n_prompt_tokens_cache = 0;
     int32_t n_prompt_tokens_processed = 0;
@@ -149,7 +151,8 @@ struct server_slot {
     struct common_sampler * smpl = nullptr;

-    llama_token sampled;
+    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_tokens drafted;

     // stats
     size_t n_sent_text = 0; // number of sent text character
@@ -179,6 +182,8 @@ struct server_slot {
         stopping_word = "";
         n_sent_text = 0;

+        drafted.clear();
+        i_batch_dft.clear();
         generated_tokens.clear();
         generated_token_probs.clear();
         json_schema = json();
@@ -254,6 +259,31 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }

+    int get_n_draft_max() const {
+        if (!can_speculate()) {
+            return 0;
+        }
+
+        // determine the max draft that fits the current slot state
+        int n_draft_max = task->params.speculative.n_max;
+
+        // note: slot.prompt is not yet expanded with the `id` token sampled above
+        // also, need to leave space for 1 extra token to allow context shifts
+        n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
+
+        if (n_remaining > 0) {
+            n_draft_max = std::min(n_draft_max, n_remaining - 1);
+        }
+
+        SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);
+
+        if (n_draft_max < task->params.speculative.n_min) {
+            SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
+            n_draft_max = 0;
+        }
+        return n_draft_max;
+    }
+
     void release() {
         if (is_processing()) {
             GGML_ASSERT(task);
@@ -1745,14 +1775,54 @@ struct server_context_impl {
                 continue;
             }

-            slot.i_batch = batch.n_tokens;
+            // generate draft tokens in speculative decoding mode
+            // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
+            // perform the speculative drafting for all sequences at the same time in a single batch
+            int n_draft_max = slot.get_n_draft_max();
+            if (n_draft_max > 0) {
+                if (mctx) {
+                    // we should never reach this, as speculative is automatically disabled if mmproj is loaded
+                    GGML_ABORT("not supported by multimodal");
+                }

-            common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+                params_spec.p_min = slot.task->params.speculative.p_min;
+                const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
+                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+
+                if (slot.task->params.speculative.n_min > (int) draft.size()) {
+                    // ignore small drafts
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
+
+                } else {
+                    slot.i_batch_dft.push_back(batch.n_tokens);
+                    common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                    slot.prompt.tokens.push_back(slot.sampled);
+
+                    // keep track of total number of drafted tokens tested
+                    slot.n_draft_total += draft.size();
+
+                    // add all drafted tokens to the batch
+                    for (size_t i = 0; i < draft.size(); i++) {
+                        slot.i_batch_dft.push_back(batch.n_tokens);
+                        common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
+                        slot.prompt.tokens.push_back(draft[i]);
+                    }
+                    slot.drafted = std::move(draft);
+                }
+            } else {
+                // no speculative decoding
+                slot.i_batch = batch.n_tokens;
+
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);

-            slot.prompt.tokens.push_back(slot.sampled);
+                slot.prompt.tokens.push_back(slot.sampled);

-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
-                slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+                SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
+            }
         }

         // process in chunks of params.n_batch
@@ -2341,6 +2411,10 @@ struct server_context_impl {
                     continue; // continue loop of slots
                 }

+                if (slot.i_batch_dft.size() > 0) {
+                    continue; // sample using speculative decoding
+                }
+
                 const int tok_idx = slot.i_batch - i;

                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2381,84 +2455,30 @@ struct server_context_impl {
                 }
             }

-        // do speculative decoding
-        // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
-        // perform the speculative drafting for all sequences at the same time in a single batch
+        // speculative decoding - main model sample and accept
         for (auto & slot : slots) {
-            if (!slot.is_processing() || !slot.can_speculate()) {
+            if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
                 continue;
             }

-            if (slot.state != SLOT_STATE_GENERATING) {
-                continue;
-            }
-
-            if (mctx) {
-                // we should never reach this, as speculative is automatically disabled if mmproj is loaded
-                GGML_ABORT("not supported by multimodal");
-            }
-
-            // determine the max draft that fits the current slot state
-            int n_draft_max = slot.task->params.speculative.n_max;
-
-            // note: slot.prompt is not yet expanded with the `id` token sampled above
-            // also, need to leave space for 1 extra token to allow context shifts
-            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2);
-
-            if (slot.n_remaining > 0) {
-                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-            }
-
-            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-            if (n_draft_max < slot.task->params.speculative.n_min) {
-                SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min);
-
-                continue;
-            }
-
-            llama_token id = slot.sampled;
-
-            struct common_speculative_params params_spec;
-            params_spec.n_draft = n_draft_max;
-            params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-            params_spec.p_min = slot.task->params.speculative.p_min;
-
-            const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-            llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
-
-            // ignore small drafts
-            if (slot.task->params.speculative.n_min > (int) draft.size()) {
-                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
-
-                continue;
-            }
-
-            // keep track of total number of drafted tokens tested
-            slot.n_draft_total += draft.size();
-
-            // construct the speculation batch
-            common_batch_clear(slot.batch_spec);
-            common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true);
-
-            for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true);
-            }
-
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-            llama_decode(ctx, slot.batch_spec);
+            size_t n_draft = slot.drafted.size();

             // the accepted tokens from the speculation
-            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
+            slot.i_batch_dft.clear();
+            slot.drafted.clear();

             slot.n_decoded += ids.size();

             // update how many tokens out of those tested were accepted
             slot.n_draft_accepted += ids.size() - 1;

-            slot.prompt.tokens.push_back(id);
+            // rollback to the state before sampling the draft tokens
+            slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+
+            // add accepted tokens to the prompt
             slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
+            slot.sampled = ids.back(); // last accepted token

             llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
@@ -2481,7 +2501,7 @@ struct server_context_impl {
                 }
             }

-            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
+            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
         }
     }

From cac8d7b24a5a868fe92fbea07f5a5838ed12b03c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 6 Dec 2025 15:52:24 +0100
Subject: [PATCH 2/7] fix small draft case

---
 tools/server/server-context.cpp | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 9277496eb3c..3d36c1939d7 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1792,14 +1792,19 @@ struct server_context_impl {
                 const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

+                // add the sampled token to the batch
+                slot.i_batch_dft.push_back(batch.n_tokens);
+                common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
+                slot.prompt.tokens.push_back(slot.sampled);
+
                 if (slot.task->params.speculative.n_min > (int) draft.size()) {
-                    // ignore small drafts
                     SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
+                    // fallback to normal decoding
+                    slot.i_batch = slot.i_batch_dft[0];
+                    slot.drafted.clear();
+                    slot.i_batch_dft.clear();

                 } else {
-                    slot.i_batch_dft.push_back(batch.n_tokens);
-                    common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
-                    slot.prompt.tokens.push_back(slot.sampled);
-
                     // keep track of total number of drafted tokens tested
                     slot.n_draft_total += draft.size();

From 398ae8dbd97384b6f1eff0b47687eb82d30fd037 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 6 Dec 2025 16:00:22 +0100
Subject: [PATCH 3/7] add link to the PR

---
 tools/server/server-context.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3d36c1939d7..d628553c2c8 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -93,8 +93,6 @@ struct server_slot {
     int32_t n_remaining = -1;
     int32_t i_batch = -1;

-    std::vector<int32_t> i_batch_dft; // idx of draft tokens in the main batch
-
     int32_t n_prompt_tokens_cache = 0;
     int32_t n_prompt_tokens_processed = 0;
@@ -103,6 +101,11 @@ struct server_slot {
     std::string generated_text;
     llama_tokens generated_tokens;

+    // idx of draft tokens in the main batch
+    // non-empty if we went to evaluate draft tokens
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17808
+    std::vector<int32_t> i_batch_dft;
+
     std::vector<completion_token_output> generated_token_probs;

     bool has_next_token = true;

From 084cec955be90e34c3da1fa66d325495d317a8d1 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 8 Dec 2025 10:17:56 +0200
Subject: [PATCH 4/7] server : fix generation time measurement

---
 tools/server/server-context.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index d628553c2c8..72977bdea27 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -2385,6 +2385,8 @@ struct server_context_impl {
         // on successful decode, restore the original batch size
         n_batch = llama_n_batch(ctx);

+        const int64_t t_current = ggml_time_us();
+
         for (auto & slot : slots) {
             // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
@@ -2433,8 +2435,6 @@ struct server_context_impl {
                 slot.n_decoded += 1;

-                const int64_t t_current = ggml_time_us();
-
                 if (slot.n_decoded == 1) {
                     slot.t_start_generation = t_current;
                     slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2478,6 +2478,8 @@ struct server_context_impl {
             slot.n_decoded += ids.size();

+            slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
+
             // update how many tokens out of those tested were accepted
             slot.n_draft_accepted += ids.size() - 1;

From f74d1ee95fabcdbaa35209da74e5208a0a00bbc4 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 8 Dec 2025 10:25:21 +0200
Subject: [PATCH 5/7] server : fix draft acceptance logs (add SRV_CNT, SLT_CNT macros)

---
 tools/server/server-common.h    | 2 ++
 tools/server/server-context.cpp | 3 +--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/tools/server/server-common.h b/tools/server/server-common.h
index bb04e82b4f5..90edd95a7db 100644
--- a/tools/server/server-common.h
+++ b/tools/server/server-common.h
@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
 using json = nlohmann::ordered_json;

 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
+#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
 #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

 #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
LOG_CNT("" fmt, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 72977bdea27..dc54888d8c6 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -376,8 +376,7 @@ struct server_slot { if (n_draft_total > 0) { const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" + SLT_CNT(*this, "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", draft_ratio, n_draft_accepted, n_draft_total ); From 75be6ba0bf22a7829df01b81d7e6d55bfb871331 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 8 Dec 2025 10:34:35 +0200 Subject: [PATCH 6/7] server : add comment --- tools/server/server-context.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index dc54888d8c6..be09d859fbc 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1805,9 +1805,7 @@ struct server_context_impl { slot.i_batch = slot.i_batch_dft[0]; slot.drafted.clear(); slot.i_batch_dft.clear(); - } else { - // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); @@ -2384,6 +2382,8 @@ struct server_context_impl { // on successful decode, restore the original batch size n_batch = llama_n_batch(ctx); + // technically, measuring the time here excludes the sampling time for the last batch + // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok const int64_t t_current = ggml_time_us(); for (auto & slot : slots) { From 0a63bd807ddcb8e1b008add6519f22716bd247c4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 8 Dec 2025 14:30:08 +0100 Subject: [PATCH 7/7] add PR to docs --- tools/server/README-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 67ebe1aafee..df165c34a3c 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -81,6 +81,7 @@ For detailed instructions, see the [test documentation](./tests/README.md). - Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216 - Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362 - Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470 +- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808