WIP speculative

This commit is contained in:
Ettore Di Giacinto 2025-01-24 10:17:54 +01:00
parent 9a1182fa01
commit b16a01d0bd

View File

@ -22,6 +22,7 @@
#include "backend.grpc.pb.h"
#include "utils.hpp"
#include "sampling.h"
#include "speculative.h"
// include std::regex
#include <cstddef>
#include <thread>
@ -185,12 +186,45 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
return out;
}
struct llama_slot_params {
uint32_t seed = -1; // RNG seed
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_adapter_lora_info> lora;
std::vector<std::string> antiprompt;
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
bool ignore_eos = false;
json input_prefix;
json input_suffix;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
};
struct llama_client_slot
{
int id;
int task_id = -1;
struct slot_params params;
struct llama_slot_params params;
common_speculative * spec = nullptr;
llama_batch batch_spec = {};
slot_state state = IDLE;
slot_command command = NONE;
@ -283,6 +317,7 @@ struct llama_client_slot
images.clear();
}
bool has_budget(common_params &global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1)
{
@ -454,6 +489,10 @@ struct llama_server_context
{
llama_model *model = nullptr;
llama_context *ctx = nullptr;
common_init_result llama_init_dft;
llama_context * ctx_dft = nullptr;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
const llama_vocab * vocab = nullptr;
clip_ctx *clp_ctx = nullptr;
@ -502,6 +541,7 @@ struct llama_server_context
}
}
bool load_model(const common_params &params_)
{
params = params_;
@ -545,6 +585,45 @@ struct llama_server_context
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params.speculative.model.empty()) {
LOG("loading draft model '%s'\n", params.speculative.model.c_str());
auto params_dft = params;
params_dft.devices = params.speculative.devices;
params_dft.model = params.speculative.model;
params_dft.n_ctx = params.speculative.n_ctx == 0 ? params.n_ctx / params.n_parallel : params.speculative.n_ctx;
params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get();
if (model_dft == nullptr) {
LOG("failed to load draft model, '%s'\n", params.speculative.model.c_str());
return false;
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
LOG("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
return false;
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}
return true;
}
@ -573,6 +652,22 @@ struct llama_server_context
slot.n_ctx = n_ctx_slot;
slot.n_predict = params.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (ctx_dft == nullptr) {
LOG("%s", "failed to create draft context\n");
return;
}
slot.spec = common_speculative_init(ctx_dft);
if (slot.spec == nullptr) {
LOG("%s", "failed to create speculator\n");
return;
}
}
LOG_INFO("new slot", {
{"slot_id", slot.id},
{"n_ctx_slot", slot.n_ctx}
@ -681,9 +776,11 @@ struct llama_server_context
}
bool launch_slot_with_data(llama_client_slot* &slot, json data) {
slot_params default_params;
llama_slot_params default_params;
common_params_sampling default_sparams;
default_sparams.speculative = params_base.speculative;
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
@ -707,6 +804,15 @@ struct llama_server_context
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
slot->sparams.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
slot->sparams.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
slot->sparams.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
slot->sparams.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
slot->sparams.speculative.n_min = std::max(params.speculative.n_min, 2);
slot->sparams.speculative.n_max = std::max(params.speculative.n_max, 0);
if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", {
@ -2024,6 +2130,97 @@ struct llama_server_context
}
}
// do speculative decoding
for (auto & slot : slots) {
if (!slot.is_processing() || !(ctx_dft && params.speculative.n_max > 0)) {
continue;
}
if (slot.state != PROCESSING) {
continue;
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
// note: n_past is not yet increased for the `id` token sampled above
// also, need to leave space for 1 extra token to allow context shifts
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
if (slot.n_remaining > 0) {
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
}
LOG("max possible draft: %d\n", n_draft_max);
if (n_draft_max < slot.params.speculative.n_min) {
LOG("the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
continue;
}
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
if (slot.params.speculative.n_min > (int) draft.size()) {
LOG("ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
continue;
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
LOG("decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.ctx_sampling, ctx, draft);
slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
result.tok = ids[i];
result.text_to_send = common_token_to_piece(ctx, result.tok, params.special);
//result.prob = 1.0f; // set later
// TODO: set result.probs
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
LOG("accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}
LOG_VERBOSE("slots updated", {});
return true;
}
@ -2296,6 +2493,30 @@ static void params_parse(const backend::ModelOptions* request,
params.cpuparams.n_threads = request->threads();
params.n_gpu_layers = request->ngpulayers();
params.n_batch = request->nbatch();
params.speculative.model = request->draftmodel();
// If options is not NULL, parse options
for (int i = 0; request->options()[i] != NULL; i++) {
char *optname = strtok(request->options()[i], ":");
char *optval = strtok(NULL, ":");
if (optval == NULL) {
optval = "true";
}
if (!strcmp(optname, "speculative.n_gpu_layers")) {
params.speculative.n_gpu_layers = std::stoi(optval);
}
if (!strcmp(optname, "speculative.n_ctx")) {
params.speculative.n_ctx = std::stoi(optval);
}
}
if params.speculative.n_gpu_layers == 0 {
params.speculative.n_gpu_layers = params.n_gpu_layers;
}
if params.speculative.n_ctx == 0 {
params.speculative.n_ctx = params.n_ctx;
}
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
//params.n_parallel = 1;
const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");