wip reranking llama.cpp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2025-03-07 19:30:34 +01:00
parent 61cc76c455
commit 01e2e3dbc3
2 changed files with 61 additions and 2 deletions

View File

@ -217,6 +217,7 @@ struct llama_client_slot
bool infill = false;
bool embedding = false;
bool reranker = false;
bool has_next_token = true;
bool truncated = false;
bool stopped_eos = false;
@ -1413,7 +1414,54 @@ struct llama_server_context
queue_results.send(res);
}
void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
// Builds the rerank result for `slot` from the decoded `batch` and pushes it
// onto queue_results.
//
// The reported score is element 0 of the embedding vector fetched for a batch
// entry whose logits flag is set and whose first sequence id equals slot.id.
// When several entries qualify, the last one wins. The score keeps the -1e6f
// sentinel when params.rerank is false or when no embedding could be fetched.
// The result is always sent with error = false and stop = true.
void send_rerank(llama_client_slot &slot, const llama_batch & batch)
{
    task_result res;
    res.id = slot.task_id;
    res.multitask_id = slot.multitask_id;
    res.error = false;
    res.stop = true;

    // Sentinel that is reported when no usable embedding is found.
    float score = -1e6f;

    if (params.rerank)
    {
        for (int tok = 0; tok < batch.n_tokens; ++tok)
        {
            // Only entries flagged for output and tagged with this slot's id count.
            const bool is_slot_output = batch.logits[tok] && batch.seq_id[tok][0] == slot.id;
            if (!is_slot_output)
            {
                continue;
            }

            // Try the per-sequence lookup first, then fall back to the per-index one.
            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[tok][0]);
            if (embd == nullptr)
            {
                embd = llama_get_embeddings_ith(ctx, tok);
            }

            if (embd != nullptr)
            {
                score = embd[0];
            }
            else
            {
                LOG("failed to get embeddings");
            }
        }
    }
    else
    {
        LOG_WARNING("reranking disabled", {
            {"params.rerank", params.rerank},
        });
    }

    // The payload carries the score together with the slot's prompt token count.
    res.result_json = json
    {
        {"score", score},
        {"tokens", slot.n_prompt_tokens}
    };
    queue_results.send(res);
}
void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
{
task_server task;
task.id = task_id;
@ -1421,6 +1469,7 @@ struct llama_server_context
task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.reranking_mode = rerank;
task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id;
@ -1552,7 +1601,7 @@ struct llama_server_context
subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (infill mode, embedding mode, reranking mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
}
}
@ -1591,6 +1640,7 @@ struct llama_server_context
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode;
slot->reranker = task.reranking_mode;
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
@ -2034,6 +2084,14 @@ struct llama_server_context
continue;
}
if (slot.reranker)
{
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
continue;
}
completion_token_output result;
const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);

View File

@ -61,6 +61,7 @@ struct task_server {
json data;
bool infill_mode = false;
bool embedding_mode = false;
bool reranking_mode = false;
int multitask_id = -1;
};