diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index da190e33..d5354638 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -16,7 +16,6 @@ if (WHISPER_SDL2) llama-hparams.cpp llama-impl.cpp llama-io.cpp - llama-kv-cache.cpp llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp llama-kv-cache-recurrent.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index c0590e10..43fa60a8 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -200,7 +200,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -1707,8 +1706,14 @@ static const std::map LLM_TENSOR_INFOS = { LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) - : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; } std::string LLM_TN_IMPL::str() const { diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 930cb4ec..f3825528 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -196,7 +196,6 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, - LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 4ab57438..b130b484 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,9 +2,9 @@ #include "llama-impl.h" #include "llama-io.h" +#include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" -#include "llama-kv-cache.h" #include #include @@ -123,7 +123,7 @@ llama_context::llama_context( __func__, n_ctx_per_seq, hparams.n_ctx_train); } - if (!params.swa_full && cparams.n_seq_max > 1) { + if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); } @@ -277,10 +277,9 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - llama_kv_cache * kv_self = static_cast(memory.get()); - const auto kv_state = kv_self->init_full(); - if (!kv_state) { + const auto mstate = memory->init_full(); + if (!mstate) { throw std::runtime_error("failed to initialize KV cache"); } @@ -288,7 +287,7 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -299,7 +298,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, kv_state.get()); + auto * gf = graph_reserve(1, 1, 1, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -310,7 +309,7 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const { return cparams.n_threads_batch; } -llama_kv_cache * llama_context::get_kv_self() { - llama_kv_cache * kv_self = static_cast(memory.get()); - return kv_self; +llama_memory_t llama_context::get_memory() const { + return memory.get(); } -const llama_kv_cache * llama_context::get_kv_self() const { - llama_kv_cache * kv_self = static_cast(memory.get()); - return kv_self; +// deprecated +void llama_context::kv_self_defrag_sched() { + if (!memory) { + return; + } + + memory_force_optimize = true; } -bool llama_context::kv_self_update() { +// deprecated +bool llama_context::kv_self_update(bool optimize) { if (!memory) { return false; } - llama_kv_cache * kv_self = static_cast(memory.get()); + { + // TODO: remove in the future + optimize |= memory_force_optimize; + memory_force_optimize = false; - if (!kv_self->update(*this)) { - // no updates have been performed - return false; + const auto mstate = memory->init_update(this, optimize); + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + // noop + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + // no updates need to be performed + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); + return false; + } + } + + if (!mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); + } } - // if the KV cache did any computation, we have to reserve a new worst-case graph - const auto kv_state = kv_self->init_full(); - if (!kv_state) { - throw std::runtime_error("failed to initialize KV cache"); - } + // if the memory module did any computation, we have to reserve a new worst-case graph + { + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize memory state"); + } - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); - if (!gf) { - LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); + } } return true; @@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) { } break; case LLAMA_POOLING_TYPE_RANK: { - // extract the rerank score - a single float per sequence + // extract the rerank score - n_cls_out floats per sequence auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { const llama_seq_id seq_id = ubatch.seq_id[s][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; } - embd_seq_out[seq_id].resize(1); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: @@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) { } } - llama_kv_cache * kv_self = static_cast(memory.get()); - // temporary allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1); const llama_batch & batch = batch_allocr.batch; @@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } + bool did_optimize = false; + // handle any pending defrags/shifts - kv_self_update(); + kv_self_update(false); - llama_memory_state_ptr kv_state; - - bool did_defrag = false; + llama_memory_state_ptr mstate; while (true) { - kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); - if (!kv_state) { + mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!mstate) { return -2; } - switch (kv_state->get_status()) { + switch (mstate->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + + return -2; + } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { - if (!did_defrag) { - did_defrag = true; + if (!did_optimize) { + did_optimize = true; - kv_self->defrag_sched(-1.0f); - if (kv_self_update()) { - LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens); + if (kv_self_update(true)) { + LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens); continue; } } - LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); + LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens); return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { + LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens); + return -2; } } @@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; do { - const auto & ubatch = kv_state->get_ubatch(); + const auto & ubatch = mstate->get_ubatch(); // count the outputs in this u_batch { @@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) { ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache - llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits::max() }; + llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES]; + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + pos_min[s] = std::numeric_limits::max(); + } for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { const auto & seq_id = ubatch.seq_id[i][0]; @@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) { LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); - llama_kv_self_seq_rm(this, s, pos_min[s], -1); + memory->seq_rm(s, pos_min[s], -1); } switch (status) { @@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } while (kv_state->next()); + } while (mstate->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = kv_state->out_ids(); + auto & out_ids = mstate->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - // decide if we need to defrag the kv cache - if (cparams.defrag_thold > 0.0f) { - kv_self->defrag_sched(cparams.defrag_thold); - } - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. ggml_backend_sched_reset(sched.get()); @@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - llama_kv_cache * kv_self = static_cast(memory.get()); - - if (kv_self != nullptr) { + if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); - kv_self->state_write(io); + memory->state_write(io); } return io.n_bytes(); @@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { if (memory) { LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_read(io); + memory->state_read(io); } return io.n_bytes(); @@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s GGML_UNUSED(seq_id); if (memory) { - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_write(io, seq_id); + memory->state_write(io, seq_id); } return io.n_bytes(); @@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq GGML_UNUSED(seq_id); if (memory) { - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_read(io, seq_id); + memory->state_read(io, seq_id); } return io.n_bytes(); @@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter( const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->clear(); + memory->clear(true); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); - if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); break; } @@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter( uint32_t pos_batch = 0; do { - const auto & ubatch = kv_state->get_ubatch(); + const auto & ubatch = mstate->get_ubatch(); n_outputs = ubatch.n_tokens; - if (!kv_state->apply()) { + if (!mstate->apply()) { LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); struct ggml_context * ctx_compute_opt; { @@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter( ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; - } while (kv_state->next()); + } while (mstate->next()); } } @@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } +// deprecated llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return ctx->get_kv_self(); + return dynamic_cast(ctx->get_memory()); } // deprecated void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(); + ctx->kv_self_update(false); } enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { @@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec( return res ? 0 : -1; } +// +// memory +// + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + return ctx->get_memory(); +} + +void llama_memory_clear(llama_memory_t mem, bool data) { + if (!mem) { + return; + } + + mem->clear(data); +} + +bool llama_memory_seq_rm( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return true; + } + + return mem->seq_rm(seq_id, p0, p1); +} + +void llama_memory_seq_cp( + llama_memory_t mem, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return; + } + + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_seq_keep( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return; + } + + mem->seq_keep(seq_id); +} + +void llama_memory_seq_add( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + if (!mem) { + return; + } + + mem->seq_add(seq_id, p0, p1, delta); +} + +void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (!mem) { + return; + } + + mem->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_min(seq_id); +} + +llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_max(seq_id); +} + +bool llama_memory_can_shift(llama_memory_t mem) { + if (!mem) { + return false; + } + + return mem->get_can_shift(); +} + // // kv cache // // deprecated int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } @@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) { // deprecated // note: this is the same as above - will be removed anyway, so it's ok int32_t llama_kv_self_used_cells(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } @@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) { return res; } +// deprecated void llama_kv_self_clear(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->clear(); + llama_memory_clear(kv, true); } +// deprecated bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return true; } - return kv->seq_rm(seq_id, p0, p1); + return llama_memory_seq_rm(kv, seq_id, p0, p1); } +// deprecated void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); } +// deprecated void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_keep(seq_id); + llama_memory_seq_keep(kv, seq_id); } +// deprecated void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_add(seq_id, p0, p1, delta); + llama_memory_seq_add(kv, seq_id, p0, p1, delta); } +// deprecated void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_div(seq_id, p0, p1, d); + llama_memory_seq_div(kv, seq_id, p0, p1, d); } +// deprecated llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return -1; } - return kv->seq_pos_min(seq_id); + return llama_memory_seq_pos_min(kv, seq_id); } +// deprecated llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return -1; } - return kv->seq_pos_max(seq_id); + return llama_memory_seq_pos_max(kv, seq_id); } // deprecated void llama_kv_self_defrag(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); - if (!kv) { - return; - } - // force defrag - kv->defrag_sched(-1.0f); + ctx->kv_self_defrag_sched(); } +// deprecated bool llama_kv_self_can_shift(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return false; } - return kv->get_can_shift(); + return llama_memory_can_shift(kv); } // llama state API diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 3b880286..2e0da8c8 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -13,13 +13,12 @@ #include struct llama_model; -struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; -class llama_memory_i; -class llama_memory_state_i; +struct llama_memory_i; +struct llama_memory_state_i; struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs @@ -47,12 +46,12 @@ struct llama_context { uint32_t n_threads() const; uint32_t n_threads_batch() const; - llama_kv_cache * get_kv_self(); - const llama_kv_cache * get_kv_self() const; + llama_memory_t get_memory() const; // return true of the KV cache was updated // TODO: remove - bool kv_self_update(); + bool kv_self_update(bool optimize); + void kv_self_defrag_sched(); enum llama_pooling_type pooling_type() const; @@ -231,6 +230,9 @@ private: std::unique_ptr memory; + // TODO: temporary, until the llama_kv_self_defrag() API is removed + bool memory_force_optimize = false; + // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 727e119e..27c9ab74 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn( cur = ggml_mul(ctx0, x0, x1); cb(cur, "ffn_mul", il); } break; + case LLM_FFN_GEGLU: + { + // Split into two equal parts + int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_gelu(ctx0, x0); + cb(x0, "ffn_gelu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_geglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { @@ -769,9 +783,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { - // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) - ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); - repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] + // repeat cur to [n_embd, n_expert_used, n_tokens] + ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); cur = ggml_mul(ctx0, repeated, weights); cb(cur, "ffn_moe_weighted", il); } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index d1c5dd1b..28da6a52 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,7 +17,7 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_state_i; +struct llama_memory_state_i; class llama_kv_cache_unified_state; class llama_kv_cache_unified_iswa_state; @@ -36,6 +36,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU, LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, + LLM_FFN_GEGLU, }; enum llm_ffn_gate_type { diff --git a/examples/talk-llama/llama-kv-cache-recurrent.cpp b/examples/talk-llama/llama-kv-cache-recurrent.cpp index 641eab2f..f5c6dcd6 100644 --- a/examples/talk-llama/llama-kv-cache-recurrent.cpp +++ b/examples/talk-llama/llama-kv-cache-recurrent.cpp @@ -1,6 +1,7 @@ #include "llama-kv-cache-recurrent.h" #include "llama-impl.h" +#include "llama-io.h" #include "llama-batch.h" #include "llama-model.h" @@ -116,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( } } -void llama_kv_cache_recurrent::clear() { +void llama_kv_cache_recurrent::clear(bool data) { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); cells[i].src = -1; cells[i].tail = -1; } + head = 0; used = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); } +llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(lctx); + GGML_UNUSED(optimize); + + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); +} + bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { // simply remember the full state because it is very small for this type of cache // TODO: optimize @@ -419,17 +430,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector & ubatche return success; } -bool llama_kv_cache_recurrent::update(llama_context & lctx) { - GGML_UNUSED(lctx); - // noop - return false; -} - -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} - bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -726,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -883,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { kv_cell & cell = cells[i]; diff --git a/examples/talk-llama/llama-kv-cache-recurrent.h b/examples/talk-llama/llama-kv-cache-recurrent.h index a178ae85..d1da1225 100644 --- a/examples/talk-llama/llama-kv-cache-recurrent.h +++ b/examples/talk-llama/llama-kv-cache-recurrent.h @@ -2,7 +2,7 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" +#include "llama-memory.h" #include #include @@ -13,7 +13,7 @@ // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i // see the implementation of llama_kv_cache_unified_state_i for an example how to do it -class llama_kv_cache_recurrent : public llama_kv_cache { +class llama_kv_cache_recurrent : public llama_memory_i { public: llama_kv_cache_recurrent( const llama_model & model, @@ -29,7 +29,17 @@ public: // llama_memory_i // - void clear() override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -40,22 +50,6 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - bool prepare(const std::vector & ubatches); // find a contiguous slot of kv cells and emplace the ubatch there diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp index 0eb04563..28d18265 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp @@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( hparams.n_swa, hparams.swa_type); } -void llama_kv_cache_unified_iswa::clear() { - kv_base->clear(); - kv_swa ->clear(); +void llama_kv_cache_unified_iswa::clear(bool data) { + kv_base->clear(data); + kv_swa ->clear(data); } bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch assert(heads_base.size() == heads_swa.size()); - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique( this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); } llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); + return std::make_unique(this); } -bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { - bool res = false; - - res = res | kv_base->update(lctx); - res = res | kv_swa ->update(lctx); - - return res; -} - -void llama_kv_cache_unified_iswa::defrag_sched(float thold) { - kv_base->defrag_sched(thold); - kv_swa ->defrag_sched(thold); +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); } bool llama_kv_cache_unified_iswa::get_can_shift() const { @@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( - llama_memory_status status, - llama_kv_cache_unified_iswa * kv) : status(status) { - state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base())); - state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ())); + llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_full(); + state_swa = kv->get_swa ()->init_full(); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_update(lctx, optimize); + state_swa = kv->get_swa ()->init_update(lctx, optimize); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); } llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv, llama_sbatch sbatch, std::vector heads_base, std::vector heads_swa, std::vector ubatches) - : status(status), - sbatch(std::move(sbatch)), - ubatches(std::move(ubatches)) { - // note: here we copy the ubatches. not sure if this is ideal - state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches)); - state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); - } + : status(LLAMA_MEMORY_STATUS_SUCCESS), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; @@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; } const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return state_base.get(); + return static_cast(state_base.get()); } const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return state_swa.get(); + return static_cast(state_swa.get()); } diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-unified-iswa.h index 8b067da0..3dbf33ed 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.h +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.h @@ -11,7 +11,7 @@ // utilizes two instances of llama_kv_cache_unified // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -class llama_kv_cache_unified_iswa : public llama_kv_cache { +class llama_kv_cache_unified_iswa : public llama_memory_i { public: llama_kv_cache_unified_iswa( const llama_model & model, @@ -31,7 +31,19 @@ public: // llama_memory_i // - void clear() override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -42,24 +54,6 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - - bool get_can_shift() const override; - // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -86,12 +80,16 @@ public: // used to create a full-cache state llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv); + // used to create an update state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize); + // used to create a state from a batch llama_kv_cache_unified_iswa_state( - llama_memory_status status, llama_kv_cache_unified_iswa * kv, llama_sbatch sbatch, std::vector heads_base, @@ -120,7 +118,7 @@ public: const llama_kv_cache_unified_state * get_swa() const; private: - const llama_memory_status status; + llama_memory_status status; //llama_kv_cache_unified_iswa * kv; @@ -131,6 +129,6 @@ private: std::vector ubatches; - std::unique_ptr state_base; - std::unique_ptr state_swa; + llama_memory_state_ptr state_base; + llama_memory_state_ptr state_swa; }; diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache-unified.cpp index a8171547..3566d5fd 100644 --- a/examples/talk-llama/llama-kv-cache-unified.cpp +++ b/examples/talk-llama/llama-kv-cache-unified.cpp @@ -1,6 +1,7 @@ #include "llama-kv-cache-unified.h" #include "llama-impl.h" +#include "llama-io.h" #include "llama-model.h" #include "llama-context.h" @@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified( } } -void llama_kv_cache_unified::clear() { +void llama_kv_cache_unified::clear(bool data) { cells.reset(); head = 0; - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } } } @@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + if (seq_id >= 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } else { + // match any sequence + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); - if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { if (new_head == cells.size()) { new_head = i; } @@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, + return std::make_unique( this, std::move(sbatch), std::move(heads), std::move(ubatches)); } llama_memory_state_ptr llama_kv_cache_unified::init_full() { - return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); + return std::make_unique(this); } -std::vector llama_kv_cache_unified::prepare(const std::vector & ubatches) { - std::vector res; +llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { + bool do_shift = get_has_shift(); + + defrag_info dinfo; + + // see if we need to defrag + { + bool do_defrag = optimize; + + const auto thold = lctx->get_cparams().defrag_thold; + + if (!do_defrag && thold > 0.0f) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } + } + + if (do_defrag) { + dinfo = defrag_prepare(lctx->graph_max_nodes()); + } + } + + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); +} + +llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::ubatch_heads res; struct state { uint32_t head_old; // old position of the head, before placing the ubatch @@ -359,12 +410,12 @@ std::vector llama_kv_cache_unified::prepare(const std::vectorget_sched(); - if (cells.get_has_shift()) { + if (do_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } @@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { ggml_backend_sched_reset(sched); - auto * gf = lctx.graph_init(); + auto * gf = lctx->graph_init(); - auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); + auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); if (!res) { LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); return updated; @@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { res->set_inputs(nullptr); - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); return updated; } @@ -401,56 +452,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { cells.reset_shift(); } - if (do_defrag) { + if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched); + // apply moves: + { + const auto n_kv = dinfo.ids.size(); - auto * gf = lctx.graph_init(); + for (uint32_t i = 0; i < n_kv; ++i) { + assert(dinfo.ids[i] <= n_kv); - auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); - return updated; + if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { + continue; + } + + cells.mv(i, dinfo.ids[i]); } - if (!ggml_backend_sched_alloc_graph(sched, gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); - return updated; - } - - res->set_inputs(nullptr); - - if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) { - LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); - return updated; - } - - updated = true; + // reset the head so we can find the first free slot during the next ubatch + head = 0; } - do_defrag = false; + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; } return updated; } -void llama_kv_cache_unified::defrag_sched(float thold) { - const auto n_kv = cells.used_max_p1(); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; @@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const { return cells.size(); } +bool llama_kv_cache_unified::get_has_shift() const { + return cells.get_has_shift(); +} + uint32_t llama_kv_cache_unified::get_n_kv() const { return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); } @@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - //GGML_ASSERT(kv_self->size == n_ctx); - auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); ggml_set_input(inp->k_shift); for (const auto & layer : layers) { @@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( } llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const { auto res = std::make_unique(); - const auto & ids = defrag_info.ids; + const auto & ids = dinfo.ids; #if 0 // CPU defrag @@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( return res; } -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { +llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; + defrag_info res; + auto & ids = res.ids; - ids.clear(); ids.resize(n_kv, n_kv); for (uint32_t i0 = 0; i0 < n_used; ++i0) { @@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // this cell goes to (i0 + nf) ids[i1] = i0 + nf; - // move the cell meta data - cells.mv(i1, i0 + nf); - - head = n_used; - if (!cont) { n_moves++; cont = true; @@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { } if (n_moves == 0) { - return false; + return {}; } LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); - return true; + return res; } bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { @@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i if (!res) { if (seq_id == -1) { - clear(); + clear(true); } else { seq_rm(seq_id, -1, -1); } @@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return false; } - clear(); + clear(true); for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; @@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv) : status(status), kv(kv) { - n_kv = kv->get_size(); - head = 0; - } + llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { + n_kv = kv->get_size(); + head = 0; +} llama_kv_cache_unified_state::llama_kv_cache_unified_state( - llama_memory_status status, - llama_kv_cache_unified * kv, - llama_sbatch sbatch, - std::vector heads, - std::vector ubatches) - : status(status), - kv(kv), - sbatch(std::move(sbatch)), - heads(std::move(heads)), - ubatches(std::move(ubatches)) { + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { + if (!do_shift && dinfo.empty()) { + status = LLAMA_MEMORY_STATUS_NO_UPDATE; } +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + llama_kv_cache_unified::ubatch_heads heads, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) { +} llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; @@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() { bool llama_kv_cache_unified_state::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + // no ubatches -> this is a KV cache update + if (ubatches.empty()) { + kv->update(lctx, do_shift, dinfo); + + return true; + } + kv->apply_ubatch(heads[i_next], ubatches[i_next]); n_kv = kv->get_n_kv(); diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h index 1f1d44b9..49f410ef 100644 --- a/examples/talk-llama/llama-kv-cache-unified.h +++ b/examples/talk-llama/llama-kv-cache-unified.h @@ -2,8 +2,8 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" #include "llama-kv-cells.h" +#include "llama-memory.h" #include #include @@ -17,13 +17,26 @@ struct llama_context; // llama_kv_cache_unified // -class llama_kv_cache_unified : public llama_kv_cache { +class llama_kv_cache_unified : public llama_memory_i { public: static uint32_t get_padding(const llama_cparams & cparams); // this callback is used to filter out layers that should not be included in the cache using layer_filter_cb = std::function; + using ubatch_heads = std::vector; + + struct defrag_info { + bool empty() const { + return ids.empty(); + } + + // contains information about which cell moves where: + // - cell i moves to ids[i] + // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved + std::vector ids; + }; + llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -43,7 +56,19 @@ public: // llama_memory_i // - void clear() override; + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -54,24 +79,6 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; - - bool get_can_shift() const override; - // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -83,6 +90,8 @@ public: uint32_t get_size() const; + bool get_has_shift() const; + // // graph_build API // @@ -103,7 +112,9 @@ public: // find places for the provided ubatches in the cache, returns the head locations // return empty vector on failure - std::vector prepare(const std::vector & ubatches); + ubatch_heads prepare(const std::vector & ubatches); + + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); // return the cell position where we can insert the ubatch // return -1 on failure to find a contiguous slot of kv cells @@ -133,8 +144,7 @@ private: ggml_tensor * v; }; - bool do_defrag = false; - bool v_trans = true; // the value tensor is transposed + bool v_trans = true; // the value tensor is transposed // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method @@ -160,13 +170,8 @@ private: // model layer id -> KV cache layer id std::unordered_map map_layer_ids; - // defrag - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); + // return non-empty vector if cells have been moved + defrag_info defrag_prepare(int32_t n_max_nodes) const; size_t total_size() const; @@ -192,7 +197,8 @@ private: llm_graph_result_ptr build_graph_defrag( const llama_cparams & cparams, ggml_context * ctx, - ggml_cgraph * gf) const; + ggml_cgraph * gf, + const defrag_info & dinfo) const; void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -203,20 +209,29 @@ private: class llama_kv_cache_unified_state : public llama_memory_state_i { public: + // some shorthands + using ubatch_heads = llama_kv_cache_unified::ubatch_heads; + using defrag_info = llama_kv_cache_unified::defrag_info; + // used for errors llama_kv_cache_unified_state(llama_memory_status status); // used to create a full-cache state llama_kv_cache_unified_state( - llama_memory_status status, llama_kv_cache_unified * kv); - // used to create a state from a batch + // used to create an update state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo); + + // used to create a decode state from a batch llama_kv_cache_unified_state( - llama_memory_status status, llama_kv_cache_unified * kv, llama_sbatch sbatch, - std::vector heads, + ubatch_heads heads, std::vector ubatches); virtual ~llama_kv_cache_unified_state(); @@ -253,16 +268,30 @@ public: void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: - const llama_memory_status status; + llama_memory_status status; llama_kv_cache_unified * kv; + llama_context * lctx; + + // + // update state + // + + bool do_shift = false; + + defrag_info dinfo; + + // + // batch processing state + // llama_sbatch sbatch; // the index of the next ubatch to process size_t i_next = 0; - std::vector heads; + ubatch_heads heads; + std::vector ubatches; // diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp deleted file mode 100644 index aefd23e3..00000000 --- a/examples/talk-llama/llama-kv-cache.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "llama-kv-cache.h" diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 9e2c4d92..acf30aeb 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -80,6 +80,9 @@ public: assert(isrc < pos.size()); assert(idst < pos.size()); + assert(pos[idst] == -1); + assert(pos[isrc] != -1); + pos [idst] = pos [isrc]; shift[idst] = shift[isrc]; seq [idst] = seq [isrc]; @@ -144,9 +147,10 @@ public: assert(pos[i] != -1); seq_pos_rm(i); + seq[i].reset(); pos[i] = -1; - seq[i].reset(); + shift[i] = 0; used.erase(i); } @@ -164,6 +168,7 @@ public: if (seq[i].none()) { pos[i] = -1; + shift[i] = 0; used.erase(i); @@ -192,6 +197,7 @@ public: seq[i].reset(); pos[i] = -1; + shift[i] = 0; used.erase(i); @@ -317,21 +323,20 @@ public: pos[i] += d; shift[i] += d; - seq_pos_add(i); - has_shift = true; if (pos[i] < 0) { - seq_pos_rm(i); - seq[i].reset(); pos[i] = -1; + shift[i] = 0; used.erase(i); return true; } + seq_pos_add(i); + return false; } diff --git a/examples/talk-llama/llama-memory.cpp b/examples/talk-llama/llama-memory.cpp index 10173253..f1107672 100644 --- a/examples/talk-llama/llama-memory.cpp +++ b/examples/talk-llama/llama-memory.cpp @@ -1 +1,42 @@ #include "llama-memory.h" + +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) { + bool has_update = false; + + switch (s0) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s0; + } + } + + switch (s1) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s1; + } + } + + // if either status has an update, then the combined status has an update + return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; +} diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index b3799d66..991aae78 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -7,6 +7,9 @@ struct llama_ubatch; +class llama_io_write_i; +class llama_io_read_i; + struct llama_memory_params { // kv cache ggml_type type_k; @@ -16,32 +19,17 @@ struct llama_memory_params { bool swa_full; }; -// general concept of LLM memory -// the KV cache is a type of LLM memory, but there can be other types -class llama_memory_i { -public: - virtual ~llama_memory_i() = default; - - virtual void clear() = 0; - - virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; - virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; - virtual void seq_keep(llama_seq_id seq_id) = 0; - virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; - virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; - - virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; - virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; - - virtual bool get_can_edit() const = 0; -}; - enum llama_memory_status { LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_NO_UPDATE, LLAMA_MEMORY_STATUS_FAILED_PREPARE, LLAMA_MEMORY_STATUS_FAILED_COMPUTE, }; +// helper function for combining the status of two memory states +// useful for implementing hybrid memory types (e.g. iSWA) +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); + // the interface for managing the memory state during batch processing // this interface is implemented per memory type. see: // - llama_kv_cache_unified_state @@ -51,8 +39,7 @@ enum llama_memory_status { // the only method that can mutate the memory and the memory state is llama_memory_i::apply() // // TODO: rename to llama_memory_context_i ? -class llama_memory_state_i { -public: +struct llama_memory_state_i { virtual ~llama_memory_state_i() = default; // consume the current ubatch from the state and proceed to the next one @@ -69,8 +56,63 @@ public: // get the current ubatch virtual const llama_ubatch & get_ubatch() const = 0; - // get the status of the memory state + // get the status of the memory state - used for error handling and checking if any updates would be applied virtual llama_memory_status get_status() const = 0; }; using llama_memory_state_ptr = std::unique_ptr; + +// general concept of LLM memory +// the KV cache is a type of LLM memory, but there can be other types +struct llama_memory_i { + virtual ~llama_memory_i() = default; + + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; + + // prepare for any pending memory updates, such as shifts, defrags, etc. + // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update + virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + + // getters + virtual bool get_can_shift() const = 0; + + // + // ops + // + + // if data == true, the data buffers will also be cleared together with the metadata + virtual void clear(bool data) = 0; + + virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; + virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; + virtual void seq_keep(llama_seq_id seq_id) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; + virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + + virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; + virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; + + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; +}; + +using llama_memory_ptr = std::unique_ptr; + +// TODO: temporary until the llama_kv_cache is removed from the public API +struct llama_kv_cache : public llama_memory_i { + virtual ~llama_kv_cache() = default; +}; diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 9da97f1b..47497cf9 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -401,7 +401,7 @@ struct llama_mmap::impl { } } #else - throw std::runtime_error("PrefetchVirtualMemory unavailable"); + LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n"); #endif } } diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index ddb1b036..bd9e6da8 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -288,9 +288,10 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); - if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -298,28 +299,40 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || - (std::is_same::value)); break; - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); } - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + result.clear(); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result.emplace_back(value); + } + } else { + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + } return true; } template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); - if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -327,22 +340,32 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || - (std::is_same::value)); break; - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); } if (arr_info.length > N_MAX) { throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); } - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result[i] = value; + } + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } return true; } @@ -352,6 +375,8 @@ namespace GGUFMeta { return get_arr(llm_kv(kid), result, required); } + template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + template bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { auto it = kv_overrides.find(key); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 50264a69..c41ee245 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { uint32_t n_vocab = 0; ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + // for classifier models + ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); + if (!classifier_labels.empty()) { + hparams.n_cls_out = classifier_labels.size(); + } + // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: @@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false); switch (hparams.n_layer) { case 3: @@ -956,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 46: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; case LLM_ARCH_GEMMA3: { @@ -976,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); @@ -4356,6 +4367,15 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + + size_t i = 0; + for (auto label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); @@ -8484,14 +8504,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - switch (model.type) { - case LLM_TYPE_2B: - case LLM_TYPE_9B: - case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break; - default: GGML_ABORT("fatal error"); - }; - cb(Qcur, "Qcur_scaled", il); + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, @@ -8632,9 +8645,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -13600,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } +uint32_t llama_model_n_cls_out(const struct llama_model * model) { + return model->hparams.n_cls_out; +} + +const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) { + if (i < model->classifier_labels.size()) { + return model->classifier_labels[i].c_str(); + } + + return nullptr; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); @@ -13760,7 +13788,7 @@ uint64_t llama_model_size(const llama_model * model) { } const char * llama_model_chat_template(const llama_model * model, const char * name) { - const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE) : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index cbea2cb3..18b71462 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -329,6 +329,9 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // for classifier models + std::vector classifier_labels; + struct ggml_tensor * tok_embd = nullptr; struct ggml_tensor * type_embd = nullptr; struct ggml_tensor * pos_embd = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index d5a036a8..ba2e1864 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { std::string model_name; std::string tokenizer_pre; + std::string general_arch; ml.get_key(LLM_KV_GENERAL_NAME, model_name, false); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false); // model name to lowercase std::transform(model_name.begin(), model_name.end(), model_name.begin(), @@ -2091,9 +2093,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } ); - // set attributes by model/tokenizer name - if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) { - _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true); + // set attributes by model/tokenizer/architecture name + if (false + || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"}) + || _contains_any(general_arch, {"nomic-bert-moe"}) + ) { + if (token_to_id.count("") == 0) { + LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__); + } else { + _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } else if (_contains_any(model_name, {"phi-3", "phi3"})) { for (auto id : cache_special_tokens) { _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true); diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index da0f652c..015a5789 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -61,7 +61,10 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_kv_cache; + + typedef struct llama_memory_i * llama_memory_t; + + struct llama_kv_cache; // DEPRECATED (use llama_memory instead) typedef int32_t llama_pos; typedef int32_t llama_token; @@ -493,9 +496,11 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); - LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); + LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type + DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -509,6 +514,13 @@ extern "C" { // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + // Returns the number of classifier outputs (only valid for classifier models) + // Undefined behavior for non-classifier models + LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model); + + // Returns label of classifier output by index ( 1` + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + + // Returns the smallest position present in the memory for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id); + + // Returns the largest position present in the memory for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id); + + // Check if the memory supports shifting + LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); + + // + // KV cache for self-attention (TODO: deprecate in favor of llama_memory) // // Returns the number of tokens in the KV cache (slow, use only for debug) @@ -622,86 +708,95 @@ extern "C" { "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx); + DEPRECATED(LLAMA_API void llama_kv_self_clear( + struct llama_context * ctx), + "Use llama_memory_clear() instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_self_seq_rm( + DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_rm() instead"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_cp( + DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_cp() instead"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_self_seq_keep( + DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_keep() instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_add( + DEPRECATED(LLAMA_API void llama_kv_self_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - llama_pos delta); + llama_pos delta), + "Use llama_memory_seq_add() instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_div( + DEPRECATED(void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, - int d); + int d), + "Use llama_memory_seq_div() instead"); // Returns the smallest position present in the KV cache for the specified sequence // This is typically non-zero only for SWA caches // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_min( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_min() instead"); // Returns the largest position present in the KV cache for the specified sequence // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache // Return -1 if the sequence is empty - LLAMA_API llama_pos llama_kv_self_seq_pos_max( + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_pos_max() instead"); // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() - LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), + DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting - LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); + DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), + "use llama_memory_can_shift() instead"); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), + DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), "simply remove this call, updates are applied lazily on the next llama_decode()"); // @@ -709,7 +804,7 @@ extern "C" { // // Returns the *actual* size in bytes of the state - // (logits, embedding and kv_cache) + // (logits, embedding and memory) // Only use when saving the state, not when restoring it, otherwise the size may be too small. LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), @@ -765,12 +860,12 @@ extern "C" { size_t n_token_count), "use llama_state_save_file instead"); - // Get the exact size needed to copy the KV cache of a single sequence + // Get the exact size needed to copy the state of a single sequence LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, llama_seq_id seq_id); - // Copy the KV cache of a single sequence into the specified buffer + // Copy the state of a single sequence into the specified buffer LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, @@ -836,16 +931,16 @@ extern "C" { // For encode-decoder contexts, processes the batch using the encoder. // Can store the encoder output internally for later use by the decoder's cross-attention layers. // 0 - success - // < 0 - error. the KV cache state is restored to the state before this call + // < 0 - error. the memory state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); // Process a batch of tokens. - // Requires KV cache. + // Requires the context to have a memory. // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. - // Upon non-zero return values, the KV cache state is restored to the state before this call + // Upon non-zero return values, the memory state is restored to the state before this call // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 2 - aborted @@ -916,7 +1011,7 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);