mirror of https://github.com/mudler/LocalAI.git (synced 2025-04-21 09:31:29 +00:00)
wip reranking llama.cpp
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
commit 01e2e3dbc3
parent 61cc76c455
@@ -217,6 +217,7 @@ struct llama_client_slot
     bool infill = false;
     bool embedding = false;
+    bool reranker = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1413,7 +1414,54 @@ struct llama_server_context
         queue_results.send(res);
     }

-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void send_rerank(llama_client_slot &slot, const llama_batch & batch)
+    {
+        task_result res;
+        res.id = slot.task_id;
+        res.multitask_id = slot.multitask_id;
+        res.error = false;
+        res.stop = true;
+
+        float score = -1e6f; // Default score if we fail to get embeddings
+
+        if (!params.rerank)
+        {
+            LOG_WARNING("reranking disabled", {
+                {"params.rerank", params.rerank},
+            });
+        }
+        else
+        {
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+                    continue;
+                }
+
+                score = embd[0];
+            }
+        }
+
+        // Format result as JSON, similar to the embedding function
+        res.result_json = json
+        {
+            {"score", score},
+            {"tokens", slot.n_prompt_tokens}
+        };
+
+        queue_results.send(res);
+    }
+
+    void request_completion(int task_id, json data, bool infill, bool embedding, bool rerank, int multitask_id)
     {
         task_server task;
         task.id = task_id;
@@ -1421,6 +1469,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
+        task.reranking_mode = rerank;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
@@ -1552,7 +1601,7 @@ struct llama_server_context
             subtask_data["prompt"] = subtask_data["prompt"][i];

             // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multiprompt_task.reranking_mode, multitask_id);
         }
     }
@@ -1591,6 +1640,7 @@ struct llama_server_context

         slot->infill = task.infill_mode;
         slot->embedding = task.embedding_mode;
+        slot->reranker = task.reranking_mode;
         slot->task_id = task.id;
         slot->multitask_id = task.multitask_id;
@@ -2034,6 +2084,14 @@ struct llama_server_context
                     continue;
                 }

+                if (slot.reranker)
+                {
+                    send_rerank(slot, batch_view);
+                    slot.release();
+                    slot.i_batch = -1;
+                    continue;
+                }
+
                 completion_token_output result;
                 const llama_token id = common_sampler_sample(slot.ctx_sampling, ctx, slot.i_batch - i);
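The branch above routes reranker slots past the sampler: once the batch is decoded, send_rerank emits the score and the slot is released without generating any tokens. As an assumption based on upstream llama.cpp (not established by this diff), reranker models expect the query and document packed into a single sequence with special tokens, roughly [BOS] query [EOS] [SEP] document [EOS]; a sketch with a made-up helper name:

// Hypothetical helper mirroring upstream llama.cpp's rerank prompt layout:
// [BOS] query [EOS] [SEP] document [EOS]
static std::vector<llama_token> format_rerank_prompt(const llama_model * model,
                                                     const std::vector<llama_token> & query,
                                                     const std::vector<llama_token> & doc) {
    std::vector<llama_token> out;
    out.reserve(query.size() + doc.size() + 4);
    out.push_back(llama_token_bos(model));
    out.insert(out.end(), query.begin(), query.end());
    out.push_back(llama_token_eos(model));
    out.push_back(llama_token_sep(model));
    out.insert(out.end(), doc.begin(), doc.end());
    out.push_back(llama_token_eos(model));
    return out;
}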
backend/cpp/llama/utils.hpp (vendored)
@@ -61,6 +61,7 @@ struct task_server {
     json data;
     bool infill_mode = false;
     bool embedding_mode = false;
+    bool reranking_mode = false;
     int multitask_id = -1;
 };
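With reranking_mode threaded through task_server, a consumer of queue_results can read the JSON that send_rerank emits. A small illustrative sketch (handle_rerank_result is made up; task_result and the "score"/"tokens" keys come from this diff):

#include <cstdio>

// Illustrative consumer of the rerank result produced by send_rerank.
static void handle_rerank_result(const task_result & res) {
    const float score  = res.result_json.value("score", -1e6f);
    const int   tokens = res.result_json.value("tokens", 0);
    std::printf("rerank score: %f (%d prompt tokens)\n", score, tokens);
}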