From e4fa894153f9f8604fe82b1fa9bc9b2d4b66e653 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Fri, 7 Mar 2025 19:29:52 +0100
Subject: [PATCH] fix(llama.cpp): correctly handle embeddings in batches
 (#4957)

Signed-off-by: Ettore Di Giacinto
---
 backend/cpp/llama/grpc-server.cpp | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 0f3b927a..0d17141f 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -1350,7 +1350,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(llama_client_slot &slot)
+    void send_embedding(llama_client_slot &slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1372,10 +1372,38 @@ struct llama_server_context
         else
         {
             const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
+            std::vector<float> embd_res(n_embd, 0.0f);
+            std::vector<std::vector<float>> embedding;
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                }
+
+                if (embd == NULL) {
+                    LOG("failed to get embeddings");
+
+                    continue;
+                }
+
+                // normalize only when there is pooling
+                // TODO: configurable
+                if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+                    common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                    embedding.push_back(embd_res);
+                } else {
+                    embedding.push_back({ embd, embd + n_embd });
+                }
+            }
+
+            // OAI compat
             res.result_json = json
             {
-                {"embedding", embedding },
+                {"embedding", embedding[0] },
             };
         }
         queue_results.send(res);
@@ -1996,7 +2024,7 @@ struct llama_server_context
                 // prompt evaluated for embedding
                 if (slot.embedding)
                 {
-                    send_embedding(slot);
+                    send_embedding(slot, batch_view);
                     slot.release();
                     slot.i_batch = -1;
                     continue;
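
Note on the normalization step (not part of the diff above): when a pooling type is active, the patched send_embedding() passes each per-sequence embedding through common_embd_normalize(embd, embd_res.data(), n_embd, 2), i.e. the returned vector is L2-normalized to unit length, and only the first collected embedding (embedding[0]) is placed in the response for OpenAI compatibility. The standalone sketch below only illustrates that L2 normalization; normalize_l2 is a hypothetical helper written for this note, not a function from llama.cpp or LocalAI.

// Illustration only: what an L2 (p = 2) embedding normalization does.
// normalize_l2() is a hypothetical stand-in for the behaviour the patch
// relies on from common_embd_normalize(..., 2); it is not part of the diff.
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> normalize_l2(const float * embd, int n_embd) {
    double sum = 0.0;
    for (int i = 0; i < n_embd; ++i) {
        sum += (double) embd[i] * embd[i];
    }
    const double norm = std::sqrt(sum);
    std::vector<float> out(n_embd, 0.0f);
    if (norm > 0.0) {
        for (int i = 0; i < n_embd; ++i) {
            out[i] = (float) (embd[i] / norm);
        }
    }
    return out;
}

int main() {
    const float raw[4] = { 3.0f, 0.0f, 4.0f, 0.0f };
    for (float v : normalize_l2(raw, 4)) {
        std::printf("%.2f ", v); // prints 0.60 0.00 0.80 0.00
    }
    std::printf("\n");
    return 0;
}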