1 change: 1 addition & 0 deletions tools/server/README-dev.md
@@ -81,6 +81,7 @@ For detailed instructions, see the [test documentation](./tests/README.md).
- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216
- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362
- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470
- Speculative decoding rework: https://github.com/ggml-org/llama.cpp/pull/17808



2 changes: 2 additions & 0 deletions tools/server/server-common.h
@@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
using json = nlohmann::ordered_json;

#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
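// the *_CNT variants log a continuation line without the usual "slot ..."/"srv ..." prefix (useful for multi-line messages)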
#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)

#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
181 changes: 105 additions & 76 deletions tools/server/server-context.cpp
@@ -102,6 +102,11 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;

// batch indices of the sampled token and its drafted continuation in the main batch
// non-empty if draft tokens were added to the batch for evaluation
// ref: https://github.com/ggml-org/llama.cpp/pull/17808
std::vector<int32_t> i_batch_dft;

std::vector<completion_token_output> generated_token_probs;

bool has_next_token = true;
@@ -150,7 +155,8 @@ struct server_slot {

struct common_sampler * smpl = nullptr;

llama_token sampled;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted; // draft tokens awaiting verification by the main model

// stats
size_t n_sent_text = 0; // number of sent text character
@@ -180,6 +186,8 @@ struct server_slot {
stopping_word = "";
n_sent_text = 0;

drafted.clear();
i_batch_dft.clear();
generated_tokens.clear();
generated_token_probs.clear();
json_schema = json();
@@ -255,6 +263,31 @@ struct server_slot {
generated_token_probs.push_back(token);
}

int get_n_draft_max() const {
if (!can_speculate()) {
return 0;
}

// determine the max draft that fits the current slot state
int n_draft_max = task->params.speculative.n_max;

// note: prompt is not yet expanded with the token sampled in the current iteration
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2);
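// the -2 accounts for the pending sampled token (not yet in the prompt) plus 1 token of headroom for a context shift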

if (n_remaining > 0) {
n_draft_max = std::min(n_draft_max, n_remaining - 1);
}

SLT_DBG(*this, "max possible draft: %d\n", n_draft_max);

if (n_draft_max < task->params.speculative.n_min) {
SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min);
n_draft_max = 0;
}
return n_draft_max;
}

// note: a slot can also be either a parent or a child
bool is_parent() const {
return is_processing() && task->n_children > 0;
@@ -353,8 +386,7 @@ struct server_slot {

if (n_draft_total > 0) {
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
SLT_CNT(*this,
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
draft_ratio, n_draft_accepted, n_draft_total
);
@@ -1774,14 +1806,57 @@ struct server_context_impl {
continue;
}

// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
//       perform the speculative drafting for all sequences at the same time in a single batch
int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
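// n_reuse limits how much of the existing context the draft model may reuse, leaving room for up to n_max newly drafted tokens;
// p_min is the minimum probability required to keep accepting tokens into the draft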
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
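// note: the first entry of i_batch_dft is the batch index of the sampled token itself; its logits
// are used to verify the first drafted token once the batch has been decoded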

if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();

// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
} else {
// no speculative decoding
slot.i_batch = batch.n_tokens;

common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);

SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
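// the batch now contains this slot's last sampled token and, if drafting succeeded, its draft
// continuation; the draft tokens are verified against the main model after decoding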
}

// process in chunks of params.n_batch
@@ -2345,6 +2420,10 @@ struct server_context_impl {
// on successful decode, restore the original batch size
n_batch = llama_n_batch(ctx);

// technically, measuring the time here excludes the sampling time of the last batch,
// but we also want to avoid making too many system calls just to measure time, so this is acceptable
const int64_t t_current = ggml_time_us();

for (auto & slot : slots) {
// may need to copy state to other slots
if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
@@ -2399,6 +2478,10 @@ struct server_context_impl {
continue; // continue loop of slots
}

if (!slot.i_batch_dft.empty()) {
continue; // this slot is sampled in the speculative decoding pass below
}

const int tok_idx = slot.i_batch - i;

llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
@@ -2409,8 +2492,6 @@ struct server_context_impl {

slot.n_decoded += 1;

if (slot.n_decoded == 1) {
slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
@@ -2439,84 +2520,32 @@ struct server_context_impl {
}
}

// speculative decoding - main model sample and accept
for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) {
continue;
}

size_t n_draft = slot.drafted.size();

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear();
slot.drafted.clear();
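// ids holds the draft tokens that the main model agreed with (ids.size() - 1 of them) plus one
// final token sampled by the main model at the first mismatch (or as a bonus if all were accepted)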

slot.n_decoded += ids.size();

slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;

// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;

// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

// add accepted tokens to the prompt
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token

llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
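// drop the KV cache entries past the accepted tokens (i.e. the rejected draft tokens) so the cache stays consistent with slot.prompt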

@@ -2539,7 +2568,7 @@ struct server_context_impl {
}
}

SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens());
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens());
}
}
