talk-llama : sync llama.cpp

ggml-ci
Author: Georgi Gerganov
Date:   2025-06-10 10:12:44 +03:00
Parent: 96eaf46ec6
Commit: db264d6220

23 changed files with 911 additions and 437 deletions

View File

@@ -16,7 +16,6 @@ if (WHISPER_SDL2)
     llama-hparams.cpp
     llama-impl.cpp
     llama-io.cpp
-    llama-kv-cache.cpp
     llama-kv-cache-unified.cpp
     llama-kv-cache-unified-iswa.cpp
     llama-kv-cache-recurrent.cpp

View File

@@ -200,7 +200,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,         "tokenizer.huggingface.json"      },
     { LLM_KV_TOKENIZER_RWKV,            "tokenizer.rwkv.world"            },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,   "tokenizer.chat_template"         },
-    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s"      },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,      "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,      "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,      "tokenizer.ggml.fim_mid_token_id" },

@@ -1707,8 +1706,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
-                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+
+    if (suffix != nullptr) {
+        name += ".";
+        name += suffix;
+    }
+
+    return name;
 }
 
 std::string LLM_TN_IMPL::str() const {
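For context, per-variant keys are now composed by appending the suffix to the base key instead of going through a separate "%s" entry. A minimal sketch of the resulting behavior (the "tool_use" suffix is only an example value):

    LLM_KV kv(LLM_ARCH_LLAMA, "tool_use");
    kv(LLM_KV_TOKENIZER_CHAT_TEMPLATE);                           // -> "tokenizer.chat_template.tool_use"
    LLM_KV(LLM_ARCH_LLAMA, nullptr)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); // no suffix -> "tokenizer.chat_template"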

View File

@@ -196,7 +196,6 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,

View File

@@ -2,9 +2,9 @@
 
 #include "llama-impl.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
 #include <cinttypes>
 #include <cstring>
@@ -123,7 +123,7 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (!params.swa_full && cparams.n_seq_max > 1) {
+    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
         LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
                 __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
     }
@@ -277,10 +277,9 @@ llama_context::llama_context(
         int n_nodes_tg = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 

@@ -288,7 +287,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }

@@ -299,7 +298,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            auto * gf = graph_reserve(1, 1, 1, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }

@@ -310,7 +309,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-llama_kv_cache * llama_context::get_kv_self() {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
-}
-
-bool llama_context::kv_self_update() {
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
+}
+
+// deprecated
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+// deprecated
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
-    }
+        const auto mstate = memory->init_update(this, optimize);
+        switch (mstate->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
 
-    // if the KV cache did any computation, we have to reserve a new worst-case graph
-    const auto kv_state = kv_self->init_full();
-    if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
-    }
+        if (!mstate->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
+    }
 
-    const uint32_t n_seqs   = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
+            throw std::runtime_error("failed to initialize memory state");
+        }
 
-    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-    }
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        }
+    }
 
     return true;
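For reference, the deprecated defrag entry points now funnel into this function; a minimal sketch of the caller-side flow (illustrative, using the deprecated C API shown further down in this diff):

    llama_kv_self_defrag(ctx); // -> kv_self_defrag_sched(), which sets memory_force_optimize
    llama_kv_self_update(ctx); // -> kv_self_update(false); the forced flag turns this into memory->init_update(this, true)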
@@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
 
+                    const uint32_t n_cls_out = hparams.n_cls_out;
+
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
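Since the rerank output now carries n_cls_out values per sequence, callers read a small vector instead of a single float. A hedged sketch of the reading side (n_cls_out is assumed to be known from the model hparams; for typical rerankers it is 1):

    // after a successful llama_decode() with rank pooling
    const float * scores = llama_get_embeddings_seq(ctx, seq_id);
    for (uint32_t i = 0; i < n_cls_out; ++i) {
        printf("seq %d, class %u: %.4f\n", seq_id, i, scores[i]);
    }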
@@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs_all = 1; n_outputs_all = 1;
} }
bool did_optimize = false;
// handle any pending defrags/shifts // handle any pending defrags/shifts
kv_self_update(); kv_self_update(false);
llama_memory_state_ptr kv_state; llama_memory_state_ptr mstate;
bool did_defrag = false;
while (true) { while (true) {
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state) { if (!mstate) {
return -2; return -2;
} }
switch (kv_state->get_status()) { switch (mstate->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS: case LLAMA_MEMORY_STATUS_SUCCESS:
{ {
} break; } break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
return -2;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE: case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
{ {
if (!did_defrag) { if (!did_optimize) {
did_defrag = true; did_optimize = true;
kv_self->defrag_sched(-1.0f); if (kv_self_update(true)) {
if (kv_self_update()) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
continue; continue;
} }
} }
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
return 1; return 1;
} }
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{ {
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
return -2; return -2;
} }
} }
@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) {
int64_t n_outputs_prev = 0; int64_t n_outputs_prev = 0;
do { do {
const auto & ubatch = kv_state->get_ubatch(); const auto & ubatch = mstate->get_ubatch();
// count the outputs in this u_batch // count the outputs in this u_batch
{ {
@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status; ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
if (!res) { if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() }; llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0]; const auto & seq_id = ubatch.seq_id[i][0];
@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) {
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
llama_kv_self_seq_rm(this, s, pos_min[s], -1); memory->seq_rm(s, pos_min[s], -1);
} }
switch (status) { switch (status) {
@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
n_outputs_prev += n_outputs; n_outputs_prev += n_outputs;
} while (kv_state->next()); } while (mstate->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith // set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all; n_outputs = n_outputs_all;
@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) {
{ {
bool sorted_output = true; bool sorted_output = true;
auto & out_ids = kv_state->out_ids(); auto & out_ids = mstate->out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) {
// wait for the computation to finish (automatically done when obtaining the model output) // wait for the computation to finish (automatically done when obtaining the model output)
//synchronize(); //synchronize();
// decide if we need to defrag the kv cache
if (cparams.defrag_thold > 0.0f) {
kv_self->defrag_sched(cparams.defrag_thold);
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation. // overlap with device computation.
ggml_backend_sched_reset(sched.get()); ggml_backend_sched_reset(sched.get());
@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
} }
} }
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); if (memory != nullptr) {
if (kv_self != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
kv_self->state_write(io); memory->state_write(io);
} }
return io.n_bytes(); return io.n_bytes();
@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
if (memory) { if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_read(io);
kv_self->state_read(io);
} }
return io.n_bytes(); return io.n_bytes();
@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
if (memory) { if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_write(io, seq_id);
kv_self->state_write(io, seq_id);
} }
return io.n_bytes(); return io.n_bytes();
@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
if (memory) { if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_read(io, seq_id);
kv_self->state_read(io, seq_id);
} }
return io.n_bytes(); return io.n_bytes();
@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter(
const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->clear(true);
kv_self->clear();
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch; batch.n_tokens = n_batch;
@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter(
int64_t n_outputs_all = n_tokens_all; int64_t n_outputs_all = n_tokens_all;
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break; break;
} }
@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter(
uint32_t pos_batch = 0; uint32_t pos_batch = 0;
do { do {
const auto & ubatch = kv_state->get_ubatch(); const auto & ubatch = mstate->get_ubatch();
n_outputs = ubatch.n_tokens; n_outputs = ubatch.n_tokens;
if (!kv_state->apply()) { if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
break; break;
} }
auto * gf = graph_init(); auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
struct ggml_context * ctx_compute_opt; struct ggml_context * ctx_compute_opt;
{ {
@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt); ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens; pos_batch += ubatch.n_tokens;
} while (kv_state->next()); } while (mstate->next());
} }
} }
@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model(); return &ctx->get_model();
} }
// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) { llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return ctx->get_kv_self(); return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
} }
// deprecated // deprecated
void llama_kv_self_update(llama_context * ctx) { void llama_kv_self_update(llama_context * ctx) {
ctx->kv_self_update(); ctx->kv_self_update(false);
} }
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec(
return res ? 0 : -1; return res ? 0 : -1;
} }
//
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
}
mem->clear(data);
}
bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (!mem) {
return true;
}
return mem->seq_rm(seq_id, p0, p1);
}
void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (!mem) {
return;
}
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return;
}
mem->seq_keep(seq_id);
}
void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (!mem) {
return;
}
mem->seq_add(seq_id, p0, p1, delta);
}
void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (!mem) {
return;
}
mem->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return -1;
}
return mem->seq_pos_min(seq_id);
}
llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id) {
if (!mem) {
return -1;
}
return mem->seq_pos_max(seq_id);
}
bool llama_memory_can_shift(llama_memory_t mem) {
if (!mem) {
return false;
}
return mem->get_can_shift();
}
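Taken together, these wrappers form the new public memory API that replaces the llama_kv_self_* family below. A minimal usage sketch (sequence ids and positions are illustrative):

    llama_memory_t mem = llama_get_memory(ctx);

    llama_memory_seq_cp(mem, 0, 1, -1, -1);       // fork sequence 0 into sequence 1
    llama_memory_seq_rm(mem, 1, 32, -1);          // drop sequence 1 from position 32 onward

    if (llama_memory_can_shift(mem)) {
        llama_memory_seq_add(mem, 0, 32, -1, -8); // shift the tail of sequence 0 back by 8 positions
    }

    llama_memory_clear(mem, /*data=*/true);       // clear metadata and zero the buffers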
// //
// kv cache // kv cache
// //
// deprecated // deprecated
int32_t llama_kv_self_n_tokens(const llama_context * ctx) { int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); const auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return 0; return 0;
} }
@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
// deprecated // deprecated
// note: this is the same as above - will be removed anyway, so it's ok // note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) { int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); const auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return 0; return 0;
} }
@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return res; return res;
} }
// deprecated
void llama_kv_self_clear(llama_context * ctx) { void llama_kv_self_clear(llama_context * ctx) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->clear(); llama_memory_clear(kv, true);
} }
// deprecated
bool llama_kv_self_seq_rm( bool llama_kv_self_seq_rm(
llama_context * ctx, llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return true; return true;
} }
return kv->seq_rm(seq_id, p0, p1); return llama_memory_seq_rm(kv, seq_id, p0, p1);
} }
// deprecated
void llama_kv_self_seq_cp( void llama_kv_self_seq_cp(
llama_context * ctx, llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
} }
// deprecated
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_keep(seq_id); llama_memory_seq_keep(kv, seq_id);
} }
// deprecated
void llama_kv_self_seq_add( void llama_kv_self_seq_add(
llama_context * ctx, llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_add(seq_id, p0, p1, delta); llama_memory_seq_add(kv, seq_id, p0, p1, delta);
} }
// deprecated
void llama_kv_self_seq_div( void llama_kv_self_seq_div(
llama_context * ctx, llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_div(seq_id, p0, p1, d); llama_memory_seq_div(kv, seq_id, p0, p1, d);
} }
// deprecated
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_min(seq_id); return llama_memory_seq_pos_min(kv, seq_id);
} }
// deprecated
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_max(seq_id); return llama_memory_seq_pos_max(kv, seq_id);
} }
// deprecated // deprecated
void llama_kv_self_defrag(llama_context * ctx) { void llama_kv_self_defrag(llama_context * ctx) {
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}
// force defrag // force defrag
kv->defrag_sched(-1.0f); ctx->kv_self_defrag_sched();
} }
// deprecated
bool llama_kv_self_can_shift(const llama_context * ctx) { bool llama_kv_self_can_shift(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return false; return false;
} }
return kv->get_can_shift(); return llama_memory_can_shift(kv);
} }
// llama state API // llama state API

View File

@ -13,13 +13,12 @@
#include <vector> #include <vector>
struct llama_model; struct llama_model;
struct llama_kv_cache;
class llama_io_read_i; class llama_io_read_i;
class llama_io_write_i; class llama_io_write_i;
class llama_memory_i; struct llama_memory_i;
class llama_memory_state_i; struct llama_memory_state_i;
struct llama_context { struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs // init scheduler and compute buffers, reserve worst-case graphs
@ -47,12 +46,12 @@ struct llama_context {
uint32_t n_threads() const; uint32_t n_threads() const;
uint32_t n_threads_batch() const; uint32_t n_threads_batch() const;
llama_kv_cache * get_kv_self(); llama_memory_t get_memory() const;
const llama_kv_cache * get_kv_self() const;
// return true of the KV cache was updated // return true of the KV cache was updated
// TODO: remove // TODO: remove
bool kv_self_update(); bool kv_self_update(bool optimize);
void kv_self_defrag_sched();
enum llama_pooling_type pooling_type() const; enum llama_pooling_type pooling_type() const;
@ -231,6 +230,9 @@ private:
std::unique_ptr<llama_memory_i> memory; std::unique_ptr<llama_memory_i> memory;
// TODO: temporary, until the llama_kv_self_defrag() API is removed
bool memory_force_optimize = false;
// decode output (2-dimensional array: [n_outputs][n_vocab]) // decode output (2-dimensional array: [n_outputs][n_vocab])
size_t logits_size = 0; // capacity (of floats) for logits size_t logits_size = 0; // capacity (of floats) for logits
float * logits = nullptr; float * logits = nullptr;

View File

@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn(
cur = ggml_mul(ctx0, x0, x1); cur = ggml_mul(ctx0, x0, x1);
cb(cur, "ffn_mul", il); cb(cur, "ffn_mul", il);
} break; } break;
case LLM_FFN_GEGLU:
{
// Split into two equal parts
int64_t split_point = cur->ne[0] / 2;
// TODO: these conts should not be needed
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
x0 = ggml_gelu(ctx0, x0);
cb(x0, "ffn_gelu", il);
cur = ggml_mul(ctx0, x0, x1);
cb(cur, "ffn_geglu", il);
} break;
} }
if (gate && type_gate == LLM_FFN_PAR) { if (gate && type_gate == LLM_FFN_PAR) {
@ -769,9 +783,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
if (weight_before_ffn) { if (weight_before_ffn) {
// TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) // repeat cur to [n_embd, n_expert_used, n_tokens]
ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
cur = ggml_mul(ctx0, repeated, weights); cur = ggml_mul(ctx0, repeated, weights);
cb(cur, "ffn_moe_weighted", il); cb(cur, "ffn_moe_weighted", il);
} }
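For reference, the new LLM_FFN_GEGLU branch implements GEGLU gating: the up-projection produces 2*d features per token, the first d go through GELU and gate the remaining d. A rough scalar sketch of what the ggml ops above compute, using the common tanh-based GELU approximation (buffer names and the loop are illustrative only):

    #include <cmath>

    // x: one row of the up-projection output (2*d floats), out: d floats
    auto gelu = [](float v) { return 0.5f*v*(1.0f + tanhf(0.7978845608f*(v + 0.044715f*v*v*v))); };
    for (int64_t j = 0; j < d; ++j) {
        out[j] = gelu(x[j]) * x[d + j]; // ffn_geglu = gelu(x0) * x1
    }

The ggml_repeat_4d change further down replaces the old workaround of allocating a target tensor and calling ggml_repeat with a direct repeat to [n_embd, n_expert_used, n_tokens, 1].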

View File

@ -17,7 +17,7 @@ struct ggml_tensor;
struct llama_ubatch; struct llama_ubatch;
struct llama_cparams; struct llama_cparams;
class llama_memory_state_i; struct llama_memory_state_i;
class llama_kv_cache_unified_state; class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state; class llama_kv_cache_unified_iswa_state;
@ -36,6 +36,7 @@ enum llm_ffn_op_type {
LLM_FFN_RELU, LLM_FFN_RELU,
LLM_FFN_RELU_SQR, LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU, LLM_FFN_SWIGLU,
LLM_FFN_GEGLU,
}; };
enum llm_ffn_gate_type { enum llm_ffn_gate_type {

View File

@ -1,6 +1,7 @@
#include "llama-kv-cache-recurrent.h" #include "llama-kv-cache-recurrent.h"
#include "llama-impl.h" #include "llama-impl.h"
#include "llama-io.h"
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-model.h" #include "llama-model.h"
@ -116,19 +117,22 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
} }
} }
void llama_kv_cache_recurrent::clear() { void llama_kv_cache_recurrent::clear(bool data) {
for (int32_t i = 0; i < (int32_t) size; ++i) { for (int32_t i = 0; i < (int32_t) size; ++i) {
cells[i].pos = -1; cells[i].pos = -1;
cells[i].seq_id.clear(); cells[i].seq_id.clear();
cells[i].src = -1; cells[i].src = -1;
cells[i].tail = -1; cells[i].tail = -1;
} }
head = 0; head = 0;
used = 0; used = 0;
if (data) {
for (auto & buf : bufs) { for (auto & buf : bufs) {
ggml_backend_buffer_clear(buf.get(), 0); ggml_backend_buffer_clear(buf.get(), 0);
} }
}
} }
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this); return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
} }
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
GGML_UNUSED(lctx);
GGML_UNUSED(optimize);
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
}
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) { bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
// simply remember the full state because it is very small for this type of cache // simply remember the full state because it is very small for this type of cache
// TODO: optimize // TODO: optimize
@ -419,17 +430,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
return success; return success;
} }
bool llama_kv_cache_recurrent::update(llama_context & lctx) {
GGML_UNUSED(lctx);
// noop
return false;
}
void llama_kv_cache_recurrent::defrag_sched(float thold) {
GGML_UNUSED(thold);
// noop
}
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seqs = ubatch.n_seqs;
@ -726,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
if (!res) { if (!res) {
if (seq_id == -1) { if (seq_id == -1) {
clear(); clear(true);
} else { } else {
seq_rm(seq_id, -1, -1); seq_rm(seq_id, -1, -1);
} }
@ -883,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
return false; return false;
} }
clear(); clear(true);
for (uint32_t i = 0; i < cell_count; ++i) { for (uint32_t i = 0; i < cell_count; ++i) {
kv_cell & cell = cells[i]; kv_cell & cell = cells[i];

View File

@ -2,7 +2,7 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h" #include "llama-memory.h"
#include <set> #include <set>
#include <vector> #include <vector>
@ -13,7 +13,7 @@
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_memory_i {
public: public:
llama_kv_cache_recurrent( llama_kv_cache_recurrent(
const llama_model & model, const llama_model & model,
@ -29,7 +29,17 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override; llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -40,22 +50,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
bool prepare(const std::vector<llama_ubatch> & ubatches); bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there // find a contiguous slot of kv cells and emplace the ubatch there

View File

@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
hparams.n_swa, hparams.swa_type); hparams.n_swa, hparams.swa_type);
} }
void llama_kv_cache_unified_iswa::clear() { void llama_kv_cache_unified_iswa::clear(bool data) {
kv_base->clear(); kv_base->clear(data);
kv_swa ->clear(); kv_swa ->clear(data);
} }
bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
assert(heads_base.size() == heads_swa.size()); assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, return std::make_unique<llama_kv_cache_unified_iswa_state>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} }
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this); return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
} }
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
bool res = false; return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
res = res | kv_base->update(lctx);
res = res | kv_swa ->update(lctx);
return res;
}
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
kv_base->defrag_sched(thold);
kv_swa ->defrag_sched(thold);
} }
bool llama_kv_cache_unified_iswa::get_can_shift() const { bool llama_kv_cache_unified_iswa::get_can_shift() const {
@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_memory_status status, llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
llama_kv_cache_unified_iswa * kv) : status(status) { state_base = kv->get_base()->init_full();
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); state_swa = kv->get_swa ()->init_full();
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_base = kv->get_base()->init_update(lctx, optimize);
state_swa = kv->get_swa ()->init_update(lctx, optimize);
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
} }
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_memory_status status,
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch, llama_sbatch sbatch,
std::vector<uint32_t> heads_base, std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa, std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches) std::vector<llama_ubatch> ubatches)
: status(status), : status(LLAMA_MEMORY_STATUS_SUCCESS),
sbatch(std::move(sbatch)), sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)) { ubatches(std::move(ubatches)) {
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
}
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
}
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return state_base.get(); return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
} }
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return state_swa.get(); return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
} }

View File

@ -11,7 +11,7 @@
// utilizes two instances of llama_kv_cache_unified // utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_unified_iswa : public llama_memory_i {
public: public:
llama_kv_cache_unified_iswa( llama_kv_cache_unified_iswa(
const llama_model & model, const llama_model & model,
@ -31,7 +31,19 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override; llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -42,24 +54,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
bool get_can_shift() const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@ -86,12 +80,16 @@ public:
// used to create a full-cache state // used to create a full-cache state
llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_state(
llama_memory_status status,
llama_kv_cache_unified_iswa * kv); llama_kv_cache_unified_iswa * kv);
// used to create an update state
llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a state from a batch // used to create a state from a batch
llama_kv_cache_unified_iswa_state( llama_kv_cache_unified_iswa_state(
llama_memory_status status,
llama_kv_cache_unified_iswa * kv, llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch, llama_sbatch sbatch,
std::vector<uint32_t> heads_base, std::vector<uint32_t> heads_base,
@ -120,7 +118,7 @@ public:
const llama_kv_cache_unified_state * get_swa() const; const llama_kv_cache_unified_state * get_swa() const;
private: private:
const llama_memory_status status; llama_memory_status status;
//llama_kv_cache_unified_iswa * kv; //llama_kv_cache_unified_iswa * kv;
@ -131,6 +129,6 @@ private:
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
std::unique_ptr<llama_kv_cache_unified_state> state_base; llama_memory_state_ptr state_base;
std::unique_ptr<llama_kv_cache_unified_state> state_swa; llama_memory_state_ptr state_swa;
}; };

View File

@ -1,6 +1,7 @@
#include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified.h"
#include "llama-impl.h" #include "llama-impl.h"
#include "llama-io.h"
#include "llama-model.h" #include "llama-model.h"
#include "llama-context.h" #include "llama-context.h"
@ -128,14 +129,16 @@ llama_kv_cache_unified::llama_kv_cache_unified(
} }
} }
void llama_kv_cache_unified::clear() { void llama_kv_cache_unified::clear(bool data) {
cells.reset(); cells.reset();
head = 0; head = 0;
if (data) {
for (auto & buf : bufs) { for (auto & buf : bufs) {
ggml_backend_buffer_clear(buf.get(), 0); ggml_backend_buffer_clear(buf.get(), 0);
} }
}
} }
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@ -149,6 +152,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
p1 = std::numeric_limits<llama_pos>::max(); p1 = std::numeric_limits<llama_pos>::max();
} }
if (seq_id >= 0) {
for (uint32_t i = 0; i < cells.size(); ++i) { for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) { if (!cells.pos_in(i, p0, p1)) {
continue; continue;
@ -160,6 +164,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} }
} }
} }
} else {
// match any sequence
for (uint32_t i = 0; i < cells.size(); ++i) {
if (!cells.pos_in(i, p0, p1)) {
continue;
}
cells.rm(i);
if (new_head == cells.size()) {
new_head = i;
}
}
}
// If we freed up a slot, set head to it so searching can start there. // If we freed up a slot, set head to it so searching can start there.
if (new_head != cells.size() && new_head < head) { if (new_head != cells.size() && new_head < head) {
@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, return std::make_unique<llama_kv_cache_unified_state>(
this, std::move(sbatch), std::move(heads), std::move(ubatches)); this, std::move(sbatch), std::move(heads), std::move(ubatches));
} }
llama_memory_state_ptr llama_kv_cache_unified::init_full() { llama_memory_state_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this); return std::make_unique<llama_kv_cache_unified_state>(this);
} }
std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) { llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
std::vector<uint32_t> res; bool do_shift = get_has_shift();
defrag_info dinfo;
// see if we need to defrag
{
bool do_defrag = optimize;
const auto thold = lctx->get_cparams().defrag_thold;
if (!do_defrag && thold > 0.0f) {
const auto n_kv = cells.used_max_p1();
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
if (fragmentation > thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
do_defrag = true;
}
}
if (do_defrag) {
dinfo = defrag_prepare(lctx->graph_max_nodes());
}
}
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
}
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
llama_kv_cache_unified::ubatch_heads res;
struct state { struct state {
uint32_t head_old; // old position of the head, before placing the ubatch uint32_t head_old; // old position of the head, before placing the ubatch
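The defrag trigger that previously lived in defrag_sched() now runs inside init_update() with the same formula. A worked example (numbers are illustrative): with cells.used_max_p1() == 4096, cells.get_used() == 3000 and n_pad == 32, fragmentation = 1 - (3000 + 32)/4096 ≈ 0.26, so any defrag_thold below 0.26 would request a defrag; caches with used_max_p1() < 2048 never trigger it.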
@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
return res; return res;
} }
bool llama_kv_cache_unified::update(llama_context & lctx) { bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
bool updated = false; bool updated = false;
auto * sched = lctx.get_sched(); auto * sched = lctx->get_sched();
if (cells.get_has_shift()) { if (do_shift) {
if (!get_can_shift()) { if (!get_can_shift()) {
GGML_ABORT("The current KV cache / model configuration does not support K-shift"); GGML_ABORT("The current KV cache / model configuration does not support K-shift");
} }
@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched); ggml_backend_sched_reset(sched);
auto * gf = lctx.graph_init(); auto * gf = lctx->graph_init();
auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
if (!res) { if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
return updated; return updated;
@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
res->set_inputs(nullptr); res->set_inputs(nullptr);
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
return updated; return updated;
} }
@ -401,15 +452,32 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
cells.reset_shift(); cells.reset_shift();
} }
if (do_defrag) { if (!dinfo.empty()) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
if (defrag_prepare(lctx.graph_max_nodes())) { // apply moves:
{
const auto n_kv = dinfo.ids.size();
for (uint32_t i = 0; i < n_kv; ++i) {
assert(dinfo.ids[i] <= n_kv);
if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
continue;
}
cells.mv(i, dinfo.ids[i]);
}
// reset the head so we can find the first free slot during the next ubatch
head = 0;
}
ggml_backend_sched_reset(sched); ggml_backend_sched_reset(sched);
auto * gf = lctx.graph_init(); auto * gf = lctx->graph_init();
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
if (!res) { if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
return updated; return updated;
@ -422,7 +490,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
res->set_inputs(nullptr); res->set_inputs(nullptr);
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
return updated; return updated;
} }
@ -430,27 +498,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
updated = true; updated = true;
} }
do_defrag = false;
}
return updated; return updated;
} }
void llama_kv_cache_unified::defrag_sched(float thold) {
const auto n_kv = cells.used_max_p1();
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
do_defrag = true;
}
}
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
return cells.size(); return cells.size();
} }
bool llama_kv_cache_unified::get_has_shift() const {
return cells.get_has_shift();
}
uint32_t llama_kv_cache_unified::get_n_kv() const { uint32_t llama_kv_cache_unified::get_n_kv() const {
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
} }
@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
const auto & n_embd_head_k = hparams.n_embd_head_k; const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v; //const auto & n_embd_head_v = hparams.n_embd_head_v;
//GGML_ASSERT(kv_self->size == n_ctx);
auto inp = std::make_unique<llm_graph_input_k_shift>(this); auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
ggml_set_input(inp->k_shift); ggml_set_input(inp->k_shift);
for (const auto & layer : layers) { for (const auto & layer : layers) {
@ -928,10 +980,11 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
const llama_cparams & cparams, const llama_cparams & cparams,
ggml_context * ctx, ggml_context * ctx,
ggml_cgraph * gf) const { ggml_cgraph * gf,
const defrag_info & dinfo) const {
auto res = std::make_unique<llm_graph_result>(); auto res = std::make_unique<llm_graph_result>();
const auto & ids = defrag_info.ids; const auto & ids = dinfo.ids;
#if 0 #if 0
// CPU defrag // CPU defrag
@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
return res; return res;
} }
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
const uint32_t n_layer = layers.size(); const uint32_t n_layer = layers.size();
const uint32_t n_kv = cells.used_max_p1(); const uint32_t n_kv = cells.used_max_p1();
@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
// determine which KV cells to move where // determine which KV cells to move where
// defrag_info res;
// cell i moves to ids[i] auto & ids = res.ids;
//
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
//
auto & ids = defrag_info.ids;
ids.clear();
ids.resize(n_kv, n_kv); ids.resize(n_kv, n_kv);
for (uint32_t i0 = 0; i0 < n_used; ++i0) { for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
// this cell goes to (i0 + nf) // this cell goes to (i0 + nf)
ids[i1] = i0 + nf; ids[i1] = i0 + nf;
// move the cell meta data
cells.mv(i1, i0 + nf);
head = n_used;
if (!cont) { if (!cont) {
n_moves++; n_moves++;
cont = true; cont = true;
@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
} }
if (n_moves == 0) { if (n_moves == 0) {
return false; return {};
} }
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
return true; return res;
} }
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
if (!res) { if (!res) {
if (seq_id == -1) { if (seq_id == -1) {
clear(); clear(true);
} else { } else {
seq_rm(seq_id, -1, -1); seq_rm(seq_id, -1, -1);
} }
@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
return false; return false;
} }
clear(); clear(true);
for (uint32_t i = 0; i < cell_count; ++i) { for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos; llama_pos pos;
@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_memory_status status, llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
llama_kv_cache_unified * kv) : status(status), kv(kv) {
n_kv = kv->get_size(); n_kv = kv->get_size();
head = 0; head = 0;
} }
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
if (!do_shift && dinfo.empty()) {
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
}
}
llama_kv_cache_unified_state::llama_kv_cache_unified_state( llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_sbatch sbatch, llama_sbatch sbatch,
std::vector<uint32_t> heads, llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
: status(status), }
kv(kv),
sbatch(std::move(sbatch)),
heads(std::move(heads)),
ubatches(std::move(ubatches)) {
}
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
bool llama_kv_cache_unified_state::apply() { bool llama_kv_cache_unified_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
// no ubatches -> this is a KV cache update
if (ubatches.empty()) {
kv->update(lctx, do_shift, dinfo);
return true;
}
kv->apply_ubatch(heads[i_next], ubatches[i_next]); kv->apply_ubatch(heads[i_next], ubatches[i_next]);
n_kv = kv->get_n_kv(); n_kv = kv->get_n_kv();
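In the hunk above, a llama_kv_cache_unified_state constructed without ubatches now represents a pending cache update (K-shift and/or defragmentation), and its apply() forwards to kv->update(lctx, do_shift, dinfo). A minimal sketch of how a caller could drive this path using the init_update()/get_status()/apply() methods introduced in this commit; the variable names and control flow are placeholders, not code from the commit:

// hypothetical driver for the new update path (sketch only)
llama_memory_state_ptr st = kv->init_update(lctx, /*optimize=*/false);

switch (st->get_status()) {
    case LLAMA_MEMORY_STATUS_NO_UPDATE:
        break; // nothing to shift or defragment
    case LLAMA_MEMORY_STATUS_SUCCESS:
        st->apply(); // with no ubatches queued, this calls kv->update(lctx, do_shift, dinfo)
        break;
    default:
        // LLAMA_MEMORY_STATUS_FAILED_* - report and recover
        break;
}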

@ -2,8 +2,8 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h" #include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -17,13 +17,26 @@ struct llama_context;
// llama_kv_cache_unified // llama_kv_cache_unified
// //
class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_unified : public llama_memory_i {
public: public:
static uint32_t get_padding(const llama_cparams & cparams); static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache // this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>; using layer_filter_cb = std::function<bool(int32_t il)>;
using ubatch_heads = std::vector<uint32_t>;
struct defrag_info {
bool empty() const {
return ids.empty();
}
// contains information about which cell moves where:
// - cell i moves to ids[i]
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
std::vector<uint32_t> ids;
};
llama_kv_cache_unified( llama_kv_cache_unified(
const llama_model & model, const llama_model & model,
layer_filter_cb && filter, layer_filter_cb && filter,
@ -43,7 +56,19 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override; llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -54,24 +79,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
bool get_can_shift() const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@ -83,6 +90,8 @@ public:
uint32_t get_size() const; uint32_t get_size() const;
bool get_has_shift() const;
// //
// graph_build API // graph_build API
// //
@ -103,7 +112,9 @@ public:
// find places for the provided ubatches in the cache, returns the head locations // find places for the provided ubatches in the cache, returns the head locations
// return empty vector on failure // return empty vector on failure
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches); ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
// return the cell position where we can insert the ubatch // return the cell position where we can insert the ubatch
// return -1 on failure to find a contiguous slot of kv cells // return -1 on failure to find a contiguous slot of kv cells
@ -133,7 +144,6 @@ private:
ggml_tensor * v; ggml_tensor * v;
}; };
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
@ -160,13 +170,8 @@ private:
// model layer id -> KV cache layer id // model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids; std::unordered_map<int32_t, int32_t> map_layer_ids;
// defrag // return non-empty vector if cells have been moved
struct { defrag_info defrag_prepare(int32_t n_max_nodes) const;
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
size_t total_size() const; size_t total_size() const;
@ -192,7 +197,8 @@ private:
llm_graph_result_ptr build_graph_defrag( llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams, const llama_cparams & cparams,
ggml_context * ctx, ggml_context * ctx,
ggml_cgraph * gf) const; ggml_cgraph * gf,
const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const; void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@ -203,20 +209,29 @@ private:
class llama_kv_cache_unified_state : public llama_memory_state_i { class llama_kv_cache_unified_state : public llama_memory_state_i {
public: public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors // used for errors
llama_kv_cache_unified_state(llama_memory_status status); llama_kv_cache_unified_state(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache state
llama_kv_cache_unified_state( llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv); llama_kv_cache_unified * kv);
// used to create a state from a batch // used to create an update state
llama_kv_cache_unified_state(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
// used to create a decode state from a batch
llama_kv_cache_unified_state( llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_sbatch sbatch, llama_sbatch sbatch,
std::vector<uint32_t> heads, ubatch_heads heads,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state(); virtual ~llama_kv_cache_unified_state();
@ -253,16 +268,30 @@ public:
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private: private:
const llama_memory_status status; llama_memory_status status;
llama_kv_cache_unified * kv; llama_kv_cache_unified * kv;
llama_context * lctx;
//
// update state
//
bool do_shift = false;
defrag_info dinfo;
//
// batch processing state
//
llama_sbatch sbatch; llama_sbatch sbatch;
// the index of the next ubatch to process // the index of the next ubatch to process
size_t i_next = 0; size_t i_next = 0;
std::vector<uint32_t> heads; ubatch_heads heads;
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
// //
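The defrag_info contract declared above (cell i moves to ids[i]; ids[i] == i or ids[i] == ids.size() means the cell stays put) can be illustrated with a small self-contained sketch; the plan values are made up for the example and are not part of this commit:

#include <cstdint>
#include <cstdio>
#include <vector>

// walk a hypothetical defrag plan and print the resulting cell moves
static void print_defrag_plan(const std::vector<uint32_t> & ids) {
    for (uint32_t i = 0; i < ids.size(); ++i) {
        if (ids[i] == i || ids[i] == ids.size()) {
            continue; // cell i is not moved (already in place, or unused)
        }
        std::printf("cell %u -> cell %u\n", i, ids[i]);
    }
}

int main() {
    // 6 cells: 0-1 used, 2-3 empty, 4-5 used; the plan compacts cells 4 and 5 into the holes
    print_defrag_plan({0, 1, 6, 6, 2, 3});
    return 0;
}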

@ -1 +0,0 @@
#include "llama-kv-cache.h"

@ -80,6 +80,9 @@ public:
assert(isrc < pos.size()); assert(isrc < pos.size());
assert(idst < pos.size()); assert(idst < pos.size());
assert(pos[idst] == -1);
assert(pos[isrc] != -1);
pos [idst] = pos [isrc]; pos [idst] = pos [isrc];
shift[idst] = shift[isrc]; shift[idst] = shift[isrc];
seq [idst] = seq [isrc]; seq [idst] = seq [isrc];
@ -144,9 +147,10 @@ public:
assert(pos[i] != -1); assert(pos[i] != -1);
seq_pos_rm(i); seq_pos_rm(i);
seq[i].reset();
pos[i] = -1; pos[i] = -1;
seq[i].reset(); shift[i] = 0;
used.erase(i); used.erase(i);
} }
@ -164,6 +168,7 @@ public:
if (seq[i].none()) { if (seq[i].none()) {
pos[i] = -1; pos[i] = -1;
shift[i] = 0;
used.erase(i); used.erase(i);
@ -192,6 +197,7 @@ public:
seq[i].reset(); seq[i].reset();
pos[i] = -1; pos[i] = -1;
shift[i] = 0;
used.erase(i); used.erase(i);
@ -317,21 +323,20 @@ public:
pos[i] += d; pos[i] += d;
shift[i] += d; shift[i] += d;
seq_pos_add(i);
has_shift = true; has_shift = true;
if (pos[i] < 0) { if (pos[i] < 0) {
seq_pos_rm(i);
seq[i].reset(); seq[i].reset();
pos[i] = -1; pos[i] = -1;
shift[i] = 0;
used.erase(i); used.erase(i);
return true; return true;
} }
seq_pos_add(i);
return false; return false;
} }

@ -1 +1,42 @@
#include "llama-memory.h" #include "llama-memory.h"
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
bool has_update = false;
switch (s0) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
has_update = true;
break;
}
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
break;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
return s0;
}
}
switch (s1) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
has_update = true;
break;
}
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
break;
}
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
return s1;
}
}
// if either status has an update, then the combined status has an update
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
}
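A few spot checks of the combination rules implemented above (standalone sketch, not part of the commit): a failure status always wins, a single SUCCESS yields SUCCESS, and two NO_UPDATE inputs stay NO_UPDATE.

#include <cassert>

#include "llama-memory.h"

int main() {
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_SUCCESS,        LLAMA_MEMORY_STATUS_NO_UPDATE)      == LLAMA_MEMORY_STATUS_SUCCESS);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_NO_UPDATE,      LLAMA_MEMORY_STATUS_NO_UPDATE)      == LLAMA_MEMORY_STATUS_NO_UPDATE);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_SUCCESS,        LLAMA_MEMORY_STATUS_FAILED_PREPARE) == LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    assert(llama_memory_status_combine(LLAMA_MEMORY_STATUS_FAILED_COMPUTE, LLAMA_MEMORY_STATUS_SUCCESS)        == LLAMA_MEMORY_STATUS_FAILED_COMPUTE);
    return 0;
}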

@ -7,6 +7,9 @@
struct llama_ubatch; struct llama_ubatch;
class llama_io_write_i;
class llama_io_read_i;
struct llama_memory_params { struct llama_memory_params {
// kv cache // kv cache
ggml_type type_k; ggml_type type_k;
@ -16,32 +19,17 @@ struct llama_memory_params {
bool swa_full; bool swa_full;
}; };
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual bool get_can_edit() const = 0;
};
enum llama_memory_status { enum llama_memory_status {
LLAMA_MEMORY_STATUS_SUCCESS = 0, LLAMA_MEMORY_STATUS_SUCCESS = 0,
LLAMA_MEMORY_STATUS_NO_UPDATE,
LLAMA_MEMORY_STATUS_FAILED_PREPARE, LLAMA_MEMORY_STATUS_FAILED_PREPARE,
LLAMA_MEMORY_STATUS_FAILED_COMPUTE, LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
}; };
// helper function for combining the status of two memory states
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
// the interface for managing the memory state during batch processing // the interface for managing the memory state during batch processing
// this interface is implemented per memory type. see: // this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state // - llama_kv_cache_unified_state
@ -51,8 +39,7 @@ enum llama_memory_status {
// the only method that can mutate the memory and the memory state is llama_memory_i::apply() // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
// //
// TODO: rename to llama_memory_context_i ? // TODO: rename to llama_memory_context_i ?
class llama_memory_state_i { struct llama_memory_state_i {
public:
virtual ~llama_memory_state_i() = default; virtual ~llama_memory_state_i() = default;
// consume the current ubatch from the state and proceed to the next one // consume the current ubatch from the state and proceed to the next one
@ -69,8 +56,63 @@ public:
// get the current ubatch // get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0; virtual const llama_ubatch & get_ubatch() const = 0;
// get the status of the memory state // get the status of the memory state - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0; virtual llama_memory_status get_status() const = 0;
}; };
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>; using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
//
// ops
//
// if data == true, the data buffers will also be cleared together with the metadata
virtual void clear(bool data) = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};
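To see how the interface above is meant to be driven, here is a minimal hypothetical consumer of a llama_memory_i implementation; mem, batch and the graph step are placeholders, and the loop shape is inferred from the state methods (apply/next/get_ubatch) rather than copied from llama.cpp:

// sketch: split a batch into ubatches and apply them one by one
static void process_batch(llama_memory_i & mem, const llama_batch & batch) {
    llama_memory_state_ptr st = mem.init_batch(batch, /*n_ubatch=*/512, /*embd_pooled=*/false, /*logits_all=*/false);

    if (st->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return; // the batch does not fit - the caller decides how to recover
    }

    do {
        st->apply();                               // reserve memory for the current ubatch
        const llama_ubatch & ub = st->get_ubatch();
        // ... build and compute the graph for ub ...
        (void) ub;
    } while (st->next());                          // advance until all ubatches are consumed
}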

@ -401,7 +401,7 @@ struct llama_mmap::impl {
} }
} }
#else #else
throw std::runtime_error("PrefetchVirtualMemory unavailable"); LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
#endif #endif
} }
} }

@ -288,9 +288,10 @@ namespace GGUFMeta {
template<typename T> template<typename T>
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) { bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
const int kid = gguf_find_key(meta.get(), key.c_str()); const gguf_context * ctx = meta.get();
const int kid = gguf_find_key(ctx, key.c_str());
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
if (required) { if (required) {
throw std::runtime_error(format("array key not found in model: %s", key.c_str())); throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
} }
@ -298,28 +299,40 @@ namespace GGUFMeta {
} }
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
switch (arr_info.gt) { switch (arr_info.gt) {
case GGUF_TYPE_UINT32: case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) || case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break; (std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break; case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
default: default:
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
} }
if constexpr (std::is_same<T, std::string>::value) {
const size_t n_items = gguf_get_arr_n(ctx, kid);
result.clear();
for (size_t i = 0; i < n_items; i++) {
const T value = gguf_get_arr_str(ctx, kid, i);
result.emplace_back(value);
}
} else {
result.resize(arr_info.length); result.resize(arr_info.length);
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
}
return true; return true;
} }
template<typename T, size_t N_MAX> template<typename T, size_t N_MAX>
bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) { bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
const int kid = gguf_find_key(meta.get(), key.c_str()); const gguf_context * ctx = meta.get();
const int kid = gguf_find_key(ctx, key.c_str());
if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
if (required) { if (required) {
throw std::runtime_error(format("array key not found in model: %s", key.c_str())); throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
} }
@ -327,22 +340,32 @@ namespace GGUFMeta {
} }
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
switch (arr_info.gt) { switch (arr_info.gt) {
case GGUF_TYPE_UINT32: case GGUF_TYPE_UINT32:
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) || case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) ||
(std::is_same<T, uint32_t>::value)); break; (std::is_same<T, uint32_t>::value)); break;
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break; case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same<T, std::string>::value)); break;
default: default:
throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
} }
if (arr_info.length > N_MAX) { if (arr_info.length > N_MAX) {
throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
} }
if constexpr (std::is_same<T, std::string>::value) {
const size_t n_items = gguf_get_arr_n(ctx, kid);
for (size_t i = 0; i < n_items; i++) {
const T value = gguf_get_arr_str(ctx, kid, i);
result[i] = value;
}
} else {
std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
}
return true; return true;
} }
@ -352,6 +375,8 @@ namespace GGUFMeta {
return get_arr(llm_kv(kid), result, required); return get_arr(llm_kv(kid), result, required);
} }
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
template<typename T> template<typename T>
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
auto it = kv_overrides.find(key); auto it = kv_overrides.find(key);

@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
uint32_t n_vocab = 0; uint32_t n_vocab = 0;
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
// for classifier models
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
if (!classifier_labels.empty()) {
hparams.n_cls_out = classifier_labels.size();
}
// arch-specific KVs // arch-specific KVs
switch (arch) { switch (arch) {
case LLM_ARCH_LLAMA: case LLM_ARCH_LLAMA:
@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 3: case 3:
@ -956,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 46: type = LLM_TYPE_27B; break; case 46: type = LLM_TYPE_27B; break;
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break; } break;
case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3:
{ {
@ -976,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN; default: type = LLM_TYPE_UNKNOWN;
} }
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
hparams.f_attention_scale = type == LLM_TYPE_27B hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k)); : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
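To make the scale concrete: assuming a Gemma-2 27B style configuration with n_embd = 4608 and 32 attention heads (values taken from the upstream reference config, not from this diff), the 27B branch gives f_attention_scale = 1/sqrt(4608/32) = 1/sqrt(144) ≈ 0.083, while the fallback branch gives 1/sqrt(n_embd_head_k), i.e. 1/16 = 0.0625 for the 256-wide heads of the smaller Gemma-2 variants (again an assumption about the configs rather than something stated here). The graph builders later in this commit then apply this factor directly to Qcur.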
@ -4356,6 +4367,15 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank);
LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
if (!classifier_labels.empty()) {
LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out);
size_t i = 0;
for (auto label : classifier_labels) {
LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str());
}
}
} }
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str());
@ -8484,14 +8504,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
// ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
switch (model.type) {
case LLM_TYPE_2B:
case LLM_TYPE_9B:
case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
default: GGML_ABORT("fatal error");
};
cb(Qcur, "Qcur_scaled", il);
cur = build_attn(inp_attn, gf, cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
@ -8632,9 +8645,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
// ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
cur = build_attn(inp_attn, gf, cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
} }
cur = build_norm(cur, cur = build_norm(cur,
@ -13600,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
return model->hparams.n_swa; return model->hparams.n_swa;
} }
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
return model->hparams.n_cls_out;
}
const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
if (i < model->classifier_labels.size()) {
return model->classifier_labels[i].c_str();
}
return nullptr;
}
// deprecated // deprecated
int32_t llama_n_ctx_train(const llama_model * model) { int32_t llama_n_ctx_train(const llama_model * model) {
return llama_model_n_ctx_train(model); return llama_model_n_ctx_train(model);
@ -13760,7 +13788,7 @@ uint64_t llama_model_size(const llama_model * model) {
} }
const char * llama_model_chat_template(const llama_model * model, const char * name) { const char * llama_model_chat_template(const llama_model * model, const char * name) {
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
const auto & it = model->gguf_kv.find(key); const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) { if (it == model->gguf_kv.end()) {
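Named chat templates now resolve through the regular LLM_KV_TOKENIZER_CHAT_TEMPLATE key with the requested name passed as the LLM_KV suffix, so callers keep using the same accessor. A brief hypothetical usage sketch ("tool_use" is an example template name, and the nullptr-on-miss behavior is inferred from the lookup above):

const char * tmpl_default = llama_model_chat_template(model, /*name=*/nullptr);
const char * tmpl_named   = llama_model_chat_template(model, "tool_use"); // hypothetical named template

if (tmpl_named == nullptr) {
    tmpl_named = tmpl_default; // fall back to the default template
}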

@ -329,6 +329,9 @@ struct llama_model {
llama_hparams hparams = {}; llama_hparams hparams = {};
llama_vocab vocab; llama_vocab vocab;
// for classifier models
std::vector<std::string> classifier_labels;
struct ggml_tensor * tok_embd = nullptr; struct ggml_tensor * tok_embd = nullptr;
struct ggml_tensor * type_embd = nullptr; struct ggml_tensor * type_embd = nullptr;
struct ggml_tensor * pos_embd = nullptr; struct ggml_tensor * pos_embd = nullptr;
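The new classifier_labels member backs the llama_model_n_cls_out() and llama_model_cls_label() accessors added to llama.h by this commit. A minimal read-back sketch (only meaningful for classifier models, as the API comments note; printf formatting is illustrative):

const uint32_t n_cls = llama_model_n_cls_out(model);

for (uint32_t i = 0; i < n_cls; ++i) {
    const char * label = llama_model_cls_label(model, i);
    printf("class %u: %s\n", i, label ? label : "(unnamed)");
}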

@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
std::string model_name; std::string model_name;
std::string tokenizer_pre; std::string tokenizer_pre;
std::string general_arch;
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false); ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
// model name to lowercase // model name to lowercase
std::transform(model_name.begin(), model_name.end(), model_name.begin(), std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@ -2091,9 +2093,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} }
); );
// set attributes by model/tokenizer name // set attributes by model/tokenizer/architecture name
if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) { if (false
|| _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
|| _contains_any(general_arch, {"nomic-bert-moe"})
) {
if (token_to_id.count("<mask>") == 0) {
LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
} else {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true); _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
}
} else if (_contains_any(model_name, {"phi-3", "phi3"})) { } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : cache_special_tokens) { for (auto id : cache_special_tokens) {
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true); _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);

@ -61,7 +61,10 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_sampler; struct llama_sampler;
struct llama_kv_cache;
typedef struct llama_memory_i * llama_memory_t;
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
typedef int32_t llama_pos; typedef int32_t llama_pos;
typedef int32_t llama_token; typedef int32_t llama_token;
@ -493,9 +496,11 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@ -509,6 +514,13 @@ extern "C" {
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
// Returns the number of classifier outputs (only valid for classifier models)
// Undefined behavior for non-classifier models
LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
// Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@ -609,7 +621,81 @@ extern "C" {
int32_t il_end); int32_t il_end);
// //
// KV cache // Memory
//
// Clear the memory contents
// If data == true, the data buffers will also be cleared together with the metadata
LLAMA_API void llama_memory_clear(
llama_memory_t mem,
bool data);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
// Copy all tokens that belong to the specified sequence to another sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
// Returns the smallest position present in the memory for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id);
// Returns the largest position present in the memory for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id);
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
//
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
// //
// Returns the number of tokens in the KV cache (slow, use only for debug) // Returns the number of tokens in the KV cache (slow, use only for debug)
@ -622,86 +708,95 @@ extern "C" {
"Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
// Clear the KV cache - both cell info is erased and KV data is zeroed // Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear( DEPRECATED(LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx); struct llama_context * ctx),
"Use llama_memory_clear() instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_self_seq_rm( DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1),
"Use llama_memory_seq_rm() instead");
// Copy all tokens that belong to the specified sequence to another sequence // Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_cp( DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1),
"Use llama_memory_seq_cp() instead");
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_self_seq_keep( DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id),
"Use llama_memory_seq_keep() instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode() // - lazily on next llama_decode()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_add( DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta); llama_pos delta),
"Use llama_memory_seq_add() instead");
// Integer division of the positions by factor of `d > 1` // Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode() // - lazily on next llama_decode()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_div( DEPRECATED(void llama_kv_self_seq_div(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d); int d),
"Use llama_memory_seq_div() instead");
// Returns the smallest position present in the KV cache for the specified sequence // Returns the smallest position present in the KV cache for the specified sequence
// This is typically non-zero only for SWA caches // This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty // Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_min( DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id),
"Use llama_memory_seq_pos_min() instead");
// Returns the largest position present in the KV cache for the specified sequence // Returns the largest position present in the KV cache for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty // Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_max( DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id),
"Use llama_memory_seq_pos_max() instead");
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied:
// - lazily on next llama_decode() // - lazily on next llama_decode()
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
// Check if the context supports KV cache shifting // Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
"use llama_memory_can_shift() instead");
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
"simply remove this call, updates are applied lazily on the next llama_decode()"); "simply remove this call, updates are applied lazily on the next llama_decode()");
// //
@ -709,7 +804,7 @@ extern "C" {
// //
// Returns the *actual* size in bytes of the state // Returns the *actual* size in bytes of the state
// (logits, embedding and kv_cache) // (logits, embedding and memory)
// Only use when saving the state, not when restoring it, otherwise the size may be too small. // Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@ -765,12 +860,12 @@ extern "C" {
size_t n_token_count), size_t n_token_count),
"use llama_state_save_file instead"); "use llama_state_save_file instead");
// Get the exact size needed to copy the KV cache of a single sequence // Get the exact size needed to copy the state of a single sequence
LLAMA_API size_t llama_state_seq_get_size( LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
// Copy the KV cache of a single sequence into the specified buffer // Copy the state of a single sequence into the specified buffer
LLAMA_API size_t llama_state_seq_get_data( LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx, struct llama_context * ctx,
uint8_t * dst, uint8_t * dst,
@ -836,16 +931,16 @@ extern "C" {
// For encode-decoder contexts, processes the batch using the encoder. // For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers. // Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the memory state is restored to the state before this call
LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
// Process a batch of tokens. // Process a batch of tokens.
// Requires KV cache. // Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder. // For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// Upon non-zero return values, the KV cache state is restored to the state before this call // Upon non-zero return values, the memory state is restored to the state before this call
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted // 2 - aborted
@ -916,7 +1011,7 @@ extern "C" {
// Get the embeddings for a sequence id // Get the embeddings for a sequence id
// Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
// when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
// otherwise: float[n_embd] (1-dimensional) // otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
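Taken together, the declarations above define the migration path away from the deprecated llama_kv_self_* calls: obtain the context's memory once with llama_get_memory() and call the llama_memory_* equivalents on it. A hedged sketch with example sequence ids and positions:

llama_memory_t mem = llama_get_memory(ctx);

llama_memory_clear  (mem, /*data=*/true);                   // was llama_kv_self_clear(ctx)
llama_memory_seq_rm (mem, /*seq_id=*/0, /*p0=*/32, -1);     // was llama_kv_self_seq_rm(ctx, 0, 32, -1)
llama_memory_seq_cp (mem, /*src=*/0, /*dst=*/1, -1, -1);    // was llama_kv_self_seq_cp(ctx, 0, 1, -1, -1)
llama_memory_seq_keep(mem, /*seq_id=*/1);                   // was llama_kv_self_seq_keep(ctx, 1)

if (llama_memory_can_shift(mem)) {
    llama_memory_seq_add(mem, /*seq_id=*/1, /*p0=*/0, /*p1=*/-1, /*delta=*/-8); // shift positions back by 8
}

const llama_pos p_min = llama_memory_seq_pos_min(mem, 1);   // -1 if the sequence is empty
const llama_pos p_max = llama_memory_seq_pos_max(mem, 1);
(void) p_min; (void) p_max;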