diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index a88b2fe3..b98e3256 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -1,5 +1,6 @@ #include "llama-batch.h" +#include #include #include @@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 batch = in_batch; GGML_ASSERT(batch.n_tokens > 0); if (!batch.pos) { + assert(p0 >= 0); pos.resize(batch.n_tokens); for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i] = i + p0; + pos[i] = p0 + i; } batch.pos = pos.data(); } diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index a3b84a6a..e153351a 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -25,7 +25,11 @@ llama_context::llama_context( const auto & hparams = model.hparams; - cparams.n_seq_max = std::max(1u, params.n_seq_max); + cparams.n_seq_max = std::max(1u, params.n_seq_max); + if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) { + throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES)); + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -93,6 +97,7 @@ llama_context::llama_context( } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.op_offload = params.op_offload; const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -176,8 +181,9 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -687,12 +693,18 @@ int llama_context::encode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int32_t i = 0; i < n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); return -1; } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + throw -1; + } } } @@ -846,7 +858,7 @@ int llama_context::encode(llama_batch & inp_batch) { int llama_context::decode(llama_batch & inp_batch) { if (!memory) { - LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(inp_batch); } @@ -855,11 +867,17 @@ int llama_context::decode(llama_batch & inp_batch) { return -1; } + if (!inp_batch.pos) { + if (inp_batch.seq_id) { + LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__); + return -1; + } + } + llama_kv_cache * kv_self = static_cast(memory.get()); // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->get_pos_max() + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1); const llama_batch & batch = batch_allocr.batch; @@ -875,11 +893,17 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + // TODO: move the validation to the llama_batch_allocr if (batch.token) { for (int64_t i = 0; i < n_tokens_all; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); - throw std::runtime_error("invalid token"); + return -1; + } + + if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES); + return -1; } } } @@ -947,8 +971,6 @@ int llama_context::decode(llama_batch & inp_batch) { // find KV slot if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - return 1; } @@ -2093,6 +2115,7 @@ llama_context_params llama_context_default_params() { /*.flash_attn =*/ false, /*.no_perf =*/ true, /*.op_offload =*/ true, + /*.swa_full =*/ true, }; return result; @@ -2287,65 +2310,51 @@ int32_t llama_apply_adapter_cvec( return res ? 0 : -1; } -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); - return {}; - } - - return llama_kv_cache_view_init(*kv, n_seq_max); -} - -void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); - return; - } - - llama_kv_cache_view_update(view, kv); -} - // // kv cache // // deprecated -int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { - return llama_kv_self_n_tokens(ctx); -} - int32_t llama_kv_self_n_tokens(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { return 0; } - return kv->get_n_tokens(); + int32_t res = 0; + + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; } // deprecated -int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { - return llama_kv_self_used_cells(ctx); -} - +// note: this is the same as above - will be removed anyway, so it's ok int32_t llama_kv_self_used_cells(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { return 0; } - return kv->get_used_cells(); -} + int32_t res = 0; -// deprecated -void llama_kv_cache_clear(llama_context * ctx) { - llama_kv_self_clear(ctx); + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; } void llama_kv_self_clear(llama_context * ctx) { @@ -2357,15 +2366,6 @@ void llama_kv_self_clear(llama_context * ctx) { kv->clear(); } -// deprecated -bool llama_kv_cache_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos 
p1) { - return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); -} - bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id seq_id, @@ -2379,16 +2379,6 @@ bool llama_kv_self_seq_rm( return kv->seq_rm(seq_id, p0, p1); } -// deprecated -void llama_kv_cache_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, @@ -2403,13 +2393,6 @@ void llama_kv_self_seq_cp( kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -// deprecated -void llama_kv_cache_seq_keep( - llama_context * ctx, - llama_seq_id seq_id) { - llama_kv_self_seq_keep(ctx, seq_id); -} - void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2419,16 +2402,6 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { kv->seq_keep(seq_id); } -// deprecated -void llama_kv_cache_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); -} - void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, @@ -2443,16 +2416,6 @@ void llama_kv_self_seq_add( kv->seq_add(seq_id, p0, p1, delta); } -// deprecated -void llama_kv_cache_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); -} - void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, @@ -2467,25 +2430,24 @@ void llama_kv_self_seq_div( kv->seq_div(seq_id, p0, p1, d); } -// deprecated -llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_pos_max(ctx, seq_id); +llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { + const auto * kv = ctx->get_kv_self(); + if (!kv) { + return -1; + } + + return kv->seq_pos_min(seq_id); } llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { const auto * kv = ctx->get_kv_self(); if (!kv) { - return 0; + return -1; } return kv->seq_pos_max(seq_id); } -// deprecated -void llama_kv_cache_defrag(llama_context * ctx) { - llama_kv_self_defrag(ctx); -} - void llama_kv_self_defrag(llama_context * ctx) { auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2496,11 +2458,6 @@ void llama_kv_self_defrag(llama_context * ctx) { kv->defrag_sched(-1.0f); } -// deprecated -bool llama_kv_cache_can_shift(const llama_context * ctx) { - return llama_kv_self_can_shift(ctx); -} - bool llama_kv_self_can_shift(const llama_context * ctx) { const auto * kv = ctx->get_kv_self(); if (!kv) { @@ -2510,11 +2467,6 @@ bool llama_kv_self_can_shift(const llama_context * ctx) { return kv->get_can_shift(); } -// deprecated -void llama_kv_cache_update(llama_context * ctx) { - llama_kv_self_update(ctx); -} - // llama state API // deprecated @@ -2637,7 +2589,21 @@ int32_t llama_encode( int32_t llama_decode( llama_context * ctx, llama_batch batch) { - const int ret = ctx->decode(batch); + int ret = ctx->decode(batch); + + // defrag and try again + // TODO: distinguish return code when we are sure that even after defrag there is no space available + if (ret == 1) { + llama_kv_self_defrag(ctx); + ret = ctx->decode(batch); + + if (ret == 1) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens); + + return ret; + } + } + if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to 
decode, ret = %d\n", __func__, ret); } diff --git a/examples/talk-llama/llama-cparams.cpp b/examples/talk-llama/llama-cparams.cpp index 28369be3..f7b36590 100644 --- a/examples/talk-llama/llama-cparams.cpp +++ b/examples/talk-llama/llama-cparams.cpp @@ -1 +1,5 @@ #include "llama-cparams.h" + +size_t llama_max_parallel_sequences(void) { + return LLAMA_MAX_PARALLEL_SEQUENCES; +} diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 246fa577..2871031e 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -4,6 +4,8 @@ #include +#define LLAMA_MAX_PARALLEL_SEQUENCES 64 + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 973b47ae..bed706bb 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token for (const auto & trigger_pattern : grammar.trigger_patterns) { if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { grammar.awaiting_trigger = false; - // get from the first match to the end of the string - auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + auto constrained_str = grammar.trigger_buffer.substr(start); // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index b0e3f635..cdd5887d 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -9,33 +9,6 @@ #include #include -static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { - // TODO move to hparams if a T5 variant appears that uses a different value - const int64_t max_distance = 128; - - if (bidirectional) { - n_buckets >>= 1; - } - - const int64_t max_exact = n_buckets >> 1; - - int32_t relative_position = x - y; - int32_t relative_bucket = 0; - - if (bidirectional) { - relative_bucket += (relative_position > 0) * n_buckets; - relative_position = abs(relative_position); - } else { - relative_position = -std::min(relative_position, 0); - } - - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); - relative_bucket += (relative_position < max_exact ? 
relative_position : relative_position_if_large); - - return relative_bucket; -} - void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - const int64_t n_tokens = ubatch->n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - - int32_t * data = (int32_t *) pos_bucket->data; - - const int64_t n_kv = kv_self->n; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); - } - } - } + kv_self->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask || self_kq_mask_swa) { - const int64_t n_kv = kv_self->n; - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; + if (self_kq_mask) { + kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } +} - float * data = nullptr; - float * data_swa = nullptr; +void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask) { + kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - if (self_kq_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - data = (float *) self_kq_mask->data; - } - - if (self_kq_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - data_swa = (float *) self_kq_mask_swa->data; - } - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; - for (int i = 0; i < n_kv; ++i) { - float f; - // mask the token if: - if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence - || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens - ) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(kv_self->cells[i].pos - pos); - } else { - f = 0.0f; - } - } - - if (data) { - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - - // may need to cut off old tokens for sliding window - // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask" - if (data_swa) { - if (hparams.n_attn_chunk) { - llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; - if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { - f = -INFINITY; - } - } else { - if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; - } - } - data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - } - - // mask padded tokens - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - - // mask padded tokens - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } + if (self_kq_mask_swa) { + kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : n_layer (hparams.n_layer), n_rot (hparams.n_rot), n_ctx (cparams.n_ctx), - n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max), n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), n_embd_head_k (hparams.n_embd_head_k), @@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { auto inp = std::make_unique(hparams, kv_self); - const auto n_kv = kv_self->n; + const auto n_kv = kv_self->get_n(); auto & cur = inp->pos_bucket; @@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kq_b, ggml_tensor * kq_mask, ggml_tensor * v_mla, - bool v_trans, float kq_scale) const { - //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const bool v_trans = v->nb[1] > v->nb[2]; - //const int64_t n_head = hparams.n_head(il); - //const int64_t n_head_kv = hparams.n_head_kv(il); - - //const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; @@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = 
inp->get_kq_mask(); - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - //cb(k, "v", il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1369,22 +1235,16 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() auto inp = std::make_unique(hparams, cparams, kv_self); - const auto n_kv = kv_self->n; + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); - ggml_set_input(inp->self_kq_mask); + const auto n_kv = kv_self->get_n(); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); - if (hparams.n_swa_pattern > 1) { - GGML_ASSERT(hparams.n_swa > 0); - - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); - ggml_set_input(inp->self_kq_mask_swa); - - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); @@ -1409,81 +1269,104 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, v_cur); const llama_kv_cache_unified * kv_self = static_cast(memory); - const auto & n_ctx = cparams.n_ctx; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - const auto n_tokens = q_cur->ne[2]; - - const bool v_trans = !cparams.flash_attn; // store to KV cache { - const auto kv_head = kv_self->head; - - GGML_ASSERT(kv_self->size == n_ctx); - - ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head); - //cb(k_cache_view, "k_cache_view", il); - - // note: storing RoPE-ed version of K in the KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); - - v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens); - - ggml_tensor * v_cache_view = nullptr; - - if (!v_trans) { - v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head); - } else { - // note: the V cache is transposed when not using flash attention - v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv_self->v_l[il]), - (kv_head)*ggml_element_size(kv_self->v_l[il])); - - v_cur = ggml_transpose(ctx0, v_cur); - } - //cb(v_cache_view, "v_cache_view", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); + ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il)); } + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_self->get_k(ctx0, il); + ggml_tensor * v = kv_self->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { + const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + + auto inp = std::make_unique(hparams, cparams, kv_self); + + { + const auto n_kv = kv_self->get_kv_base()->get_n(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + { + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); + + const auto n_kv = kv_self->get_kv_swa()->get_n(); + + inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } + + return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + const bool is_swa = hparams.is_swa(il); + const llama_kv_cache_unified_iswa * kv_self = static_cast(memory); + + const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il)); + } + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); - const auto n_kv = kv_self->n; + ggml_tensor * q = q_cur; + ggml_tensor * k = kv->get_k(ctx0, il); + ggml_tensor * v = kv->get_v(ctx0, il); - const int64_t n_head_kv = hparams.n_head_kv(il); - - const auto & n_embd_head_k = hparams.n_embd_head_k; - const auto & n_embd_head_v = hparams.n_embd_head_v; - - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - 0); - //cb(k, "k", il); - - ggml_tensor * v = !v_trans ? 
- ggml_view_3d(ctx0, kv_self->v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v), - 0) : - ggml_view_3d(ctx0, kv_self->v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self->v_l[il])*n_ctx, - ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, - 0); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1534,17 +1417,11 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask_cross(); - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - //cb(k, "v", il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale); + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1712,3 +1589,30 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } + +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? 
relative_position : relative_position_if_large); + + return relative_bucket; +} diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 832a8c09..2b85bb25 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -19,6 +19,7 @@ struct llama_cparams; class llama_memory_i; class llama_kv_cache_unified; +class llama_kv_cache_unified_iswa; class llama_kv_cache_recurrent; // certain models (typically multi-modal) can produce different types of graphs @@ -255,6 +256,31 @@ public: void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified * kv_self; +}; + +class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { +public: + llm_graph_input_attn_kv_unified_iswa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_iswa * kv_self) : + hparams(hparams), + cparams(cparams), + kv_self(kv_self) { + } + ~llm_graph_input_attn_kv_unified_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } @@ -266,7 +292,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_iswa * kv_self; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -378,7 +404,6 @@ struct llm_graph_context { const int64_t n_layer; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) - const int64_t n_ctx_per_seq; const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; @@ -507,13 +532,12 @@ struct llm_graph_context { ggml_tensor * build_attn_mha( ggml_cgraph * gf, - ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] - ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] - ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - bool v_trans, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale) const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; @@ -546,6 +570,21 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( @@ -596,3 +635,6 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) 
const; }; + +// TODO: better name +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 90dfe7a7..1499eb08 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -2,6 +2,22 @@ #include "ggml.h" +void llama_hparams::set_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } +} + +bool llama_hparams::is_swa_any() const { + for (uint32_t il = 0; il < n_layer; ++il) { + if (swa_layers[il]) { + return true; + } + } + + return false; +} + uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1); + return swa_layers[il]; } GGML_ABORT("fatal error"); diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 7ee6a5b7..2d72eab1 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -14,6 +14,12 @@ enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, }; +enum llama_swa_type { + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, +}; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -35,8 +41,6 @@ struct llama_hparams { uint32_t n_embd_features = 0; uint32_t n_layer; uint32_t n_rot; - uint32_t n_swa = 0; // sliding window attention (SWA) - uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; @@ -96,6 +100,15 @@ struct llama_hparams { std::array rope_sections; + // Sliding Window Attention (SWA) + llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + // the size of the sliding window (0 - no SWA) + uint32_t n_swa = 0; + // if swa_layers[il] == true, then layer il is SWA + // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // by default, all layers are dense + std::array swa_layers; + // for State Space Models uint32_t ssm_d_conv = 0; uint32_t ssm_d_inner = 0; @@ -116,11 +129,10 @@ struct llama_hparams { bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; + bool use_kq_norm = true; + // llama4 uint32_t n_moe_layer_step = 0; - bool use_kq_norm = true; - uint32_t n_attn_chunk = 0; - // values below seems to be fixed on llama4 uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 8192; float f_attn_temp_scale = 0.1; @@ -133,6 +145,23 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... 
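The pattern described in this comment is exactly what the new llama_hparams::set_swa_pattern() shown earlier in this patch computes (llama-hparams.cpp: swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1))). A minimal standalone sketch of that predicate, reproducing the il == 0..6 example from the comment; only the expression is taken from the patch, the surrounding program is hypothetical:

// sketch: print which layers end up SWA vs dense for a given n_pattern
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_layer   = 8;
    const uint32_t n_pattern = 3; // every 3rd layer is dense, the rest use SWA

    for (uint32_t il = 0; il < n_layer; ++il) {
        // same predicate as llama_hparams::set_swa_pattern()
        const bool is_swa = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
        printf("il = %2u: %s\n", il, is_swa ? "swa" : "dense");
    }

    return 0;
}

For n_pattern == 3 this prints swa, swa, dense, swa, swa, dense, ..., matching the comment; n_pattern == 0 marks every layer SWA and n_pattern == 1 marks every layer dense.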
+ void set_swa_pattern(uint32_t n_pattern); + + // return true if one of the layers is SWA + bool is_swa_any() const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 265db252..4a42d6ec 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -23,32 +23,21 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { } llama_kv_cache_unified::llama_kv_cache_unified( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) { - const int32_t n_layer = hparams.n_layer; + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { - has_shift = false; - can_shift = true; - - LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n", - __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding); - - GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding"); - - head = 0; - size = kv_size; - used = 0; - - this->type_k = type_k; - this->type_v = type_v; - - cells.clear(); - cells.resize(kv_size); + GGML_ASSERT(kv_size % n_pad == 0); // create a context for each buffer type std::map ctx_map; @@ -56,7 +45,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -75,37 +64,48 @@ llama_kv_cache_unified::llama_kv_cache_unified( return it->second; }; - k_l.reserve(n_layer); - v_l.reserve(n_layer); + head = 0; - for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); const char * dev_name = "CPU"; ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); if (offload) { - auto * dev = model.dev_layer(i); + auto * dev = model.dev_layer(il); buft = ggml_backend_dev_buffer_type(dev); dev_name = ggml_backend_dev_name(dev); } - LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, 
"cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -117,8 +117,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } - ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); bufs.emplace_back(buf); } @@ -126,20 +128,17 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } } void llama_kv_cache_unified::clear() { - for (int32_t i = 0; i < (int32_t) size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - } + cells.reset(); + head = 0; - used = 0; for (auto & buf : bufs) { ggml_backend_buffer_clear(buf.get(), 0); @@ -147,7 +146,7 @@ void llama_kv_cache_unified::clear() { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; + uint32_t new_head = cells.size(); if (p0 < 0) { p0 = 0; @@ -157,32 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - cells[i].pos = -1; - - if (new_head == size) { - new_head = i; - } + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; } } } // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != size && new_head < head) { + if (new_head != cells.size() && new_head < head) { head = new_head; } @@ -202,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id p1 = std::numeric_limits::max(); } - // otherwise, this is the KV of a Transformer-like model - head = 0; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { - cells[i].seq_id.insert(seq_id_dst); + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); } } } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; + uint32_t new_head = cells.size(); - for (uint32_t i = 0; i < size; ++i) { - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { new_head = i; } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { + if (new_head != cells.size() && new_head < head) { head = new_head; } } -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { return; } - uint32_t new_head = size; + uint32_t new_head = cells.size(); if (p0 < 0) { p0 = 0; @@ -254,24 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po p1 = std::numeric_limits::max(); } - // If there is no range then return early to avoid looping over the + // If there is no range then return early to avoid looping over all cells. if (p0 == p1) { return; } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - cells[i].pos += delta; - cells[i].delta += delta; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - if (cells[i].pos < 0) { - if (!cells[i].is_empty()) { - used--; - } - cells[i].pos = -1; - cells[i].seq_id.clear(); - if (new_head == size) { + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { new_head = i; } } @@ -280,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po // If we freed up a slot, set head to it so searching can start there. // Otherwise we just start the next search from the beginning. - head = new_head != size ? new_head : 0; + head = new_head != cells.size() ? 
new_head : 0; } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { @@ -301,66 +274,41 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po return; } - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } - { - llama_pos p_old = cells[i].pos; - cells[i].pos /= d; - cells[i].delta += cells[i].pos - p_old; - } + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); } } } +llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + return cells.seq_pos_min(seq_id); +} + llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = 0; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; + return cells.seq_pos_max(seq_id); } void llama_kv_cache_unified::restore() { - if (pending.ranges.empty()) { - return; + for (auto & state : recovery.states) { + cells.set(state.i, state.cells); } - uint32_t new_head = size; - - for (auto & range : pending.ranges) { - for (uint32_t i = range.c0; i < range.c1; ++i) { - cells[i].seq_id.clear(); - - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - } - - new_head = std::min(new_head, range.c0); - } - - if (new_head != size && new_head < head) { - head = new_head; - } + recovery.clear(); } void llama_kv_cache_unified::commit() { - if (pending.ranges.empty()) { - LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); + if (recovery.states.empty()) { + LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194"); return; } - pending.ranges.clear(); + recovery.clear(); } bool llama_kv_cache_unified::update(llama_context & lctx) { @@ -368,7 +316,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * sched = lctx.get_sched(); - if (has_shift) { + if (cells.get_has_shift()) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } @@ -392,13 +340,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { need_reserve = true; } - { - has_shift = false; - - for (uint32_t i = 0; i < size; ++i) { - cells[i].delta = 0; - } - } + cells.reset_shift(); } if (do_defrag) { @@ -429,7 +371,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { void llama_kv_cache_unified::defrag_sched(float thold) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; + const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -440,7 +382,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) { } void llama_kv_cache_unified::set_full() { - n = size; + n = cells.size(); // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. @@ -450,51 +392,67 @@ void llama_kv_cache_unified::set_full() { head = 0; } -llama_sbatch llama_kv_cache_unified::sbatch_init( - const llama_batch & batch, - bool logits_all) { +llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) { return llama_sbatch(batch, hparams.n_embd, true, logits_all); } -llama_ubatch llama_kv_cache_unified::ubatch_next( - llama_sbatch & sbatch, - uint32_t n_ubatch, - bool embd_pooled) const { +llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { GGML_UNUSED(embd_pooled); return sbatch.split_simple(n_ubatch); } -bool llama_kv_cache_unified::find_slot( - const llama_ubatch & ubatch) { +bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; // if we have enough unused cells before the current head -> // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { + if (head > cells.get_used() + 2*ubatch.n_tokens) { head = 0; } // otherwise, one cell per token. - if (n_tokens > size) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); return false; } +//#define FIND_SLOT_DEBUG 1 +#if FIND_SLOT_DEBUG + LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); + + // for debugging + { + std::string ss; + if (n_swa > 0) { + for (uint32_t i = 0; i < size; ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + ss += 'x'; + } + if (i%256 == 255) { + ss += '\n'; + } + } + } + LLAMA_LOG_WARN("\n%s\n", ss.c_str()); + } +#endif + uint32_t n_tested = 0; while (true) { - if (head + n_tokens > size) { - n_tested += size - head; + if (head + n_tokens > cells.size()) { + n_tested += cells.size() - head; head = 0; continue; } bool found = true; for (uint32_t i = 0; i < n_tokens; i++) { - if (cells[head + i].pos >= 0) { + // TODO: improve to accept cells that are masked by the SWA + if (!cells.is_empty(head + i)) { found = false; head += i + 1; n_tested += i + 1; @@ -506,66 +464,257 @@ bool llama_kv_cache_unified::find_slot( break; } - if (n_tested >= size) { + if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return false; } } - for (uint32_t s = 0; s < n_seqs; s++) { - for (uint32_t i = 0; i < n_seq_tokens; ++i) { - uint32_t k = s*n_seq_tokens + i; - cells[head + k].pos = ubatch.pos[k]; + // store the old state of the cells in the recovery stack + recovery.states.push_back({head, cells.cp(head, n_tokens)}); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { - cells[head + k].seq_id.insert(ubatch.seq_id[s][j]); - } + for (uint32_t i = 0; i < n_tokens; ++i) { + cells.pos_set(head + i, ubatch.pos[i]); + + for 
(int32_t j = 0; j < ubatch.n_seq_id[i]; j++) { + cells.seq_add(head + i, ubatch.seq_id[i][j]); } } - used += n_tokens; - - pending.ranges.push_back({head, head + n_tokens}); - // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); + n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); - //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); +#ifdef FIND_SLOT_DEBUG + LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa); +#endif return true; } -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - bool llama_kv_cache_unified::get_can_shift() const { - return can_shift; + return true; } -llama_pos llama_kv_cache_unified::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); +uint32_t llama_kv_cache_unified::get_n() const { + return n; +} + +uint32_t llama_kv_cache_unified::get_size() const { + return cells.size(); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); } - return pos_max; + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + ( 
head)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) { + // no pruning is needed when the cache does not use SWA + GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache"); + + int n_attended = 0; + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.seq_has(i, seq_id)) { + continue; + } + + const llama_pos p0 = cells.pos_get(i); + + if (p0 <= pmin && !is_masked_swa(p0, pmin)) { + n_attended++; + } + + if (is_masked_swa(p0, pmax)) { + cells.seq_rm(i, seq_id); + } + } + + if (n_attended < std::min(n_swa, pmin)) { + LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa); + } +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const int64_t n_tokens = ubatch->n_tokens; + const int64_t n_seq_tokens = ubatch->n_seq_tokens; + const int64_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t n_kv = n; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f = 0.0f; + + bool masked = false; + + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(i); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); + + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } + + if (masked) { + f = -INFINITY; + } + + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + + // mask padded tokens + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); + } +} + +void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) dst->data; + + const int64_t n_kv = n; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } } size_t llama_kv_cache_unified::total_size() const { size_t size = 0; + for (const auto & buf : bufs) { size += ggml_backend_buffer_get_size(buf.get()); } @@ -576,8 +725,8 @@ size_t llama_kv_cache_unified::total_size() const { size_t llama_kv_cache_unified::size_k_bytes() const { size_t size_k_bytes = 0; - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); } return size_k_bytes; @@ -586,8 +735,8 @@ size_t llama_kv_cache_unified::size_k_bytes() const { size_t llama_kv_cache_unified::size_v_bytes() const { size_t size_v_bytes = 0; - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); } return size_v_bytes; @@ -651,13 +800,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); if (k_shift) { - assert(ggml_backend_buffer_is_host(k_shift->buffer)); - - int32_t * data = (int32_t *) k_shift->data; - - for (uint32_t i = 0; i < kv_self->size; ++i) { - data[i] = kv_self->cells[i].delta; - } + kv_self->set_input_k_shift(k_shift); } } @@ -667,13 +810,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_cgraph * gf) const { auto res = std::make_unique(); - const auto & n_layer = hparams.n_layer; - const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; - const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; - //GGML_ASSERT(kv_self->size == n_ctx); auto inp = std::make_unique(this); @@ -681,24 +820,22 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); ggml_set_input(inp->k_shift); - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const bool is_swa = hparams.is_swa(il); + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - // note: the swa rope params could become part of the cparams in the future - // if we decide to make them configurable, like the non-sliding ones - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); ggml_tensor * k = - ggml_view_3d(ctx, k_l[il], - n_embd_head_k, n_head_kv, size, - ggml_row_size(k_l[il]->type, n_embd_head_k), - ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), 0); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -803,44 +940,46 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( nm++; } - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, n_embd_k_gqa, nm, - ggml_row_size(k_l[il]->type, n_embd_k_gqa), - ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); - ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, n_embd_k_gqa, nm, - ggml_row_size(k_l[il]->type, n_embd_k_gqa), - ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); ggml_tensor * view_v_src; ggml_tensor * view_v_dst; if (cparams.flash_attn) { // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, v_l[il], + view_v_src = ggml_view_2d(ctx, layer.v, n_embd_v_gqa, nm, - ggml_row_size(v_l[il]->type, n_embd_v_gqa), - ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); - view_v_dst = ggml_view_2d(ctx, v_l[il], + view_v_dst = ggml_view_2d(ctx, layer.v, n_embd_v_gqa, nm, - ggml_row_size(v_l[il]->type, n_embd_v_gqa), - ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); } else { - view_v_src = ggml_view_2d(ctx, v_l[il], + view_v_src = ggml_view_2d(ctx, layer.v, nm, n_embd_v_gqa, - ggml_row_size(v_l[il]->type, size), - ggml_row_size(v_l[il]->type, i)); + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); - view_v_dst = ggml_view_2d(ctx, v_l[il], + view_v_dst = ggml_view_2d(ctx, layer.v, nm, n_embd_v_gqa, - ggml_row_size(v_l[il]->type, size), - ggml_row_size(v_l[il]->type, id)); + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); } ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); @@ -857,10 +996,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( } bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = layers.size(); - const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); assert(n_used <= n_kv); @@ -888,9 +1027,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { ids.resize(n_kv, n_kv); for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const 
auto & cell0 = cells[i0]; - - if (!cell0.is_empty()) { + if (!cells.is_empty(i0)) { ids[i0] = i0; continue; @@ -901,7 +1038,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { uint32_t nh = 1; // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { nh++; } @@ -910,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // starting from the end, find nh non-empty cells for (; is > i0; --is) { - const auto & cell1 = cells[is]; - - if (cell1.is_empty() || ids[is] != n_kv) { + if (cells.is_empty(is) || ids[is] != n_kv) { continue; } @@ -939,9 +1074,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { // go back and move the nf cells to the hole for (; i1 < n_kv; ++i1) { - auto & cell1 = cells[i1]; - - if (cell1.is_empty() || ids[i1] != n_kv) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { if (n_moves == max_moves) { stop = true; break; @@ -955,10 +1088,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { ids[i1] = i0 + nf; // move the cell meta data - cells[i0 + nf] = cell1; + cells.mv(i1, i0 + nf); - // clear the old cell and move the head there - cell1 = kv_cell(); head = n_used; if (!cont) { @@ -993,16 +1124,30 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return true; } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; } - return 0; + return false; } void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { @@ -1011,23 +1156,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq // Count the number of cells with the specified seq_id // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { ++cell_count; - if (cell_range_begin == size) { + if (cell_range_begin == cells.size()) { cell_range_begin = i; } } else { - if (cell_range_begin != size) { + if (cell_range_begin != cells.size()) { cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; + cell_range_begin = cells.size(); } } } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); } // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count @@ -1064,17 +1210,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { for (const auto & 
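// [Editor's sketch - not part of the patch] A worked example of the two SWA variants handled
// by is_masked_swa() above, using a window of n_swa = 4 and a query at position p1 = 10.
// The enum and helper below are local stand-ins for illustration only:

#include <cassert>
#include <cstdint>

enum swa_kind { SWA_STANDARD, SWA_CHUNKED };

static bool masked_swa(swa_kind kind, int32_t n_swa, int32_t p0, int32_t p1) {
    switch (kind) {
        case SWA_STANDARD: return p1 - p0 >= n_swa;          // fixed sliding window
        case SWA_CHUNKED:  return p0 < (p1 / n_swa) * n_swa;  // window snaps to the chunk start
    }
    return false;
}

int main() {
    assert( masked_swa(SWA_STANDARD, 4, 6, 10)); // distance 4 -> masked
    assert(!masked_swa(SWA_STANDARD, 4, 7, 10)); // distance 3 -> visible
    assert( masked_swa(SWA_CHUNKED,  4, 7, 10)); // before chunk start (10/4)*4 = 8 -> masked
    assert(!masked_swa(SWA_CHUNKED,  4, 8, 10)); // inside the current chunk -> visible
}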
range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); } } } @@ -1082,7 +1235,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { const uint32_t v_trans = this->v_trans ? 1 : 0; - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = layers.size(); io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); @@ -1091,56 +1244,63 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: // Iterate and write all the keys first, each row is a cell // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type - const int32_t k_type_i = (int32_t)k_l[il]->type; + const int32_t k_type_i = (int32_t)layer.k->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; - io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + io.write_tensor(layer.k, range.first * k_size_row, buf_size); } } if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; + const int32_t v_type_i = (int32_t)layer.v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; - io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + io.write_tensor(layer.v, range.first * v_size_row, buf_size); } } } else { // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_v_gqa = 
hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; + const int32_t v_type_i = (int32_t)layer.v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + const uint32_t v_size_el = ggml_type_size(layer.v->type); io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size @@ -1153,7 +1313,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; const size_t buf_size = range_size * v_size_el; - io.write_tensor(v_l[il], src_offset, buf_size); + io.write_tensor(layer.v, src_offset, buf_size); } } } @@ -1170,8 +1330,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; @@ -1180,32 +1338,40 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell io.read_to(&pos, sizeof(pos)); io.read_to(&n_seq_id, sizeof(n_seq_id)); - if (n_seq_id != 0) { + if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); return false; } - batch.pos[i] = pos; + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + batch.pos[i] = pos; + batch.n_seq_id[i] = n_seq_id; + batch.seq_id[i] = &dest_seq_id; } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } + commit(); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + GGML_ASSERT(head + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head) == batch.pos[0]); + GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id)); } else { // whole KV cache restore - if (cell_count > size) { + if (cell_count > cells.size()) { LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); return false; } @@ -1213,34 +1379,28 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell clear(); for (uint32_t i = 0; i < cell_count; ++i) { - kv_cell & cell = cells[i]; - llama_pos pos; uint32_t n_seq_id; io.read_to(&pos, sizeof(pos)); io.read_to(&n_seq_id, sizeof(n_seq_id)); - cell.pos = pos; + cells.pos_set(i, pos); for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); - // TODO: llama_kv_cache_unified should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - 
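// [Editor's note - not part of the patch] For readability, the serialized layout produced by
// state_write_meta()/state_write_data() above, as reconstructed from the writer code:
//
//   meta, per saved cell : pos (llama_pos), n_seq_id (uint32), seq_id[n_seq_id] (llama_seq_id)
//   data header          : v_trans (uint32), n_layer (uint32)
//   per layer, K         : k_type (int32), k_row_size (uint64), then one row per saved cell
//   per layer, V (!trans): v_type (int32), v_row_size (uint64), then one row per saved cell
//   per layer, V ( trans): v_type (int32), v_elem_size (uint32), n_embd_v_gqa (uint32),
//                          then cell_count elements for each of the n_embd_v_gqa rows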
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); return false; } - cell.seq_id.insert(seq_id); + cells.seq_add(i, seq_id); } } head = 0; - used = cell_count; } return true; @@ -1249,15 +1409,16 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t v_trans; uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); io.read_to(&n_layer, sizeof(n_layer)); - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); return false; } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); return false; } if (this->v_trans != (bool) v_trans) { @@ -1266,13 +1427,15 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell } // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) k_l[il]->type; + const int32_t k_type_i = (int32_t) layer.k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return false; @@ -1281,7 +1444,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -1289,18 +1452,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); } } if (!this->v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; + const int32_t v_type_i = (int32_t)layer.v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, 
v_type_i_ref, il); return false; @@ -1309,7 +1474,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1317,18 +1482,20 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); } } } else { // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; + const int32_t v_type_i = (int32_t)layer.v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1337,7 +1504,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t v_size_el_ref; io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size_el = ggml_type_size(layer.v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); return false; @@ -1354,8 +1521,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } @@ -1364,6 +1531,193 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell return true; } +// +// llama_kv_cache_unified_iswa +// + +llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_batch, + uint32_t n_pad) : hparams(model.hparams) { + llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; + llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + + const uint32_t size_base = kv_size; + + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal 
to the base cache size and disable pruning + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + do_prune = false; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear() { + kv_base->clear(); + kv_swa ->clear(); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +void llama_kv_cache_unified_iswa::restore() { + kv_base->restore(); + kv_swa ->restore(); +} + +void llama_kv_cache_unified_iswa::commit() { + kv_base->commit(); + kv_swa ->commit(); + + // slide the attention window, forgetting/pruning old tokens that are outside the window + if (do_prune) { + for (const auto & [seq_id, entry] : pending.pos) { + kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax); + } + + } + + pending.clear(); +} + +bool llama_kv_cache_unified_iswa::update(llama_context & lctx) { + bool res = true; + + res = res & kv_base->update(lctx); + res = res & kv_swa ->update(lctx); + + return res; +} + +void llama_kv_cache_unified_iswa::defrag_sched(float thold) { + kv_base->defrag_sched(thold); + kv_swa ->defrag_sched(thold); +} + +void llama_kv_cache_unified_iswa::set_full() { + kv_base->set_full(); + kv_swa ->set_full(); +} + +llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) { + pending.clear(); + + if (do_prune) { + for (int i = 0; i < batch.n_tokens; ++i) { + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + const llama_seq_id seq_id = batch.seq_id[i][s]; + const llama_pos pos = batch.pos[i]; + + if (pending.pos.find(seq_id) == pending.pos.end()) { + pending.pos[seq_id].pmin = pos; + pending.pos[seq_id].pmax = pos; + } else { + pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos); + 
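// [Editor's sketch - not part of the patch] How the SWA cache size chosen in the constructor
// above works out numerically. The helper mirrors
//   size_swa = min(size_base, GGML_PAD(n_swa*n_seq_max + n_batch, n_pad))
// and the example numbers are made up for illustration:

#include <algorithm>
#include <cstdint>

static uint32_t swa_cache_size(uint32_t size_base, uint32_t n_swa, uint32_t n_seq_max,
                               uint32_t n_batch, uint32_t n_pad) {
    const uint32_t need   = n_swa*n_seq_max + n_batch;
    const uint32_t padded = ((need + n_pad - 1) / n_pad) * n_pad; // ~ GGML_PAD
    return std::min(size_base, padded);
}

// e.g. size_base = 16384, n_swa = 4096, n_seq_max = 2, n_batch = 512, n_pad = 32
//      -> 4096*2 + 512 = 8704 (already a multiple of 32) -> size_swa = 8704
// with swa_full == true the SWA cache is simply forced to size_base and pruning is disabled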
pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos); + } + } + } + } + + return llama_sbatch(batch, hparams.n_embd, true, logits_all); +} + +llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + GGML_UNUSED(embd_pooled); + return sbatch.split_simple(n_ubatch); +} + +bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) { + bool res = true; + + res = res & kv_base->find_slot(batch); + res = res & kv_swa ->find_slot(batch); + + return res; +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const { + return kv_swa.get(); +} + // // llama_kv_cache_recurrent // @@ -1373,19 +1727,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_type type_k, ggml_type type_v, bool offload, - uint32_t kv_size) : hparams(model.hparams) { + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; - LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); head = 0; size = kv_size; used = 0; - this->type_k = type_k; - this->type_v = type_v; - cells.clear(); cells.resize(kv_size); @@ -1623,8 +1975,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { } } -void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { return; } @@ -1647,7 +1999,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_ if (tail_id >= 0) { kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + cell.pos += shift; } } } @@ -1683,8 +2035,24 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_ } } +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = 0; + llama_pos result = -1; for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id)) { @@ -1707,8 +2075,8 @@ void llama_kv_cache_recurrent::commit() { pending.ranges.clear(); } -bool llama_kv_cache_recurrent::update(llama_context & lctx) { - GGML_UNUSED(lctx); +bool llama_kv_cache_recurrent::update(llama_context & ctx) { + 
GGML_UNUSED(ctx); return false; } @@ -1769,7 +2137,7 @@ bool llama_kv_cache_recurrent::find_slot( if (seq_id < 0 || (uint32_t) seq_id >= size) { // too big seq_id // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); return false; } if (j > 0) { @@ -1912,29 +2280,6 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } -int32_t llama_kv_cache_recurrent::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_recurrent::get_used_cells() const { - return used; -} - -llama_pos llama_kv_cache_recurrent::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - bool llama_kv_cache_recurrent::get_can_shift() const { return false; } @@ -2063,6 +2408,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq io.read_to(&cell_count, sizeof(cell_count)); bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); res = res && state_read_data(io, cell_count); @@ -2391,104 +2737,3 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } - -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) { - llama_kv_cache_view result = { - /*.n_cells = */ 0, - /*.n_seq_max = */ n_seq_max, - /*.token_count = */ 0, - /*.used_cells = */ kv.get_used_cells(), - /*.max_contiguous = */ 0, - /*.max_contiguous_idx = */ -1, - /*.cells = */ nullptr, - /*.cells_sequences = */ nullptr, - }; - - return result; -} - -void llama_kv_cache_view_free(llama_kv_cache_view * view) { - if (view->cells != nullptr) { - free(view->cells); - view->cells = nullptr; - } - if (view->cells_sequences != nullptr) { - free(view->cells_sequences); - view->cells_sequences = nullptr; - } -} - -void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) { - // TODO: rework this in the future, for now quick hack - const llama_kv_cache_unified * kvu = dynamic_cast(kv); - if (kvu == nullptr) { - LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__); - return; - } - - if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) { - view->n_cells = int32_t(kvu->size); - void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells); - GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); - view->cells = (llama_kv_cache_view_cell *)p; - p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells); - GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); - view->cells_sequences = (llama_seq_id *)p; - } - - const std::vector & kv_cells = kvu->cells; - llama_kv_cache_view_cell * c_curr = view->cells; - llama_seq_id * cs_curr = view->cells_sequences; - int32_t used_cells = 0; - int32_t token_count = 0; - int32_t curr_contig_idx = -1; - uint32_t max_contig = 0; - int32_t max_contig_idx = -1; - - for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) { - const size_t curr_size = kv_cells[i].seq_id.size(); - token_count += curr_size; - c_curr->pos = kv_cells[i].pos + 
kv_cells[i].delta; - - if (curr_size > 0) { - if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) { - max_contig = i - curr_contig_idx; - max_contig_idx = curr_contig_idx; - } - curr_contig_idx = -1; - } else if (curr_contig_idx < 0) { - curr_contig_idx = i; - } - - int seq_idx = 0; - for (const llama_seq_id it : kv_cells[i].seq_id) { - if (seq_idx >= view->n_seq_max) { - break; - } - cs_curr[seq_idx] = it; - seq_idx++; - } - if (seq_idx != 0) { - used_cells++; - } - for (; seq_idx < view->n_seq_max; seq_idx++) { - cs_curr[seq_idx] = -1; - } - } - if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) { - max_contig_idx = curr_contig_idx; - max_contig = kv_cells.size() - curr_contig_idx; - } - view->max_contiguous = max_contig; - view->max_contiguous_idx = max_contig_idx; - view->token_count = token_count; - view->used_cells = used_cells; - if (uint32_t(used_cells) != kvu->used) { - LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, kvu->used, used_cells); - } -} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index e83e12c0..ce6261e4 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -4,10 +4,12 @@ #include "llama-io.h" #include "llama-graph.h" #include "llama-memory.h" +#include "llama-kv-cells.h" #include "ggml-cpp.h" #include +#include #include struct llama_cparams; @@ -34,12 +36,16 @@ struct llama_kv_cache : public llama_memory_i { virtual void defrag_sched(float thold) = 0; // simulate full cache, used for allocating worst-case compute buffers + // TODO: remove virtual void set_full() = 0; // // batch processing // + // ============================================================================================================= + // TODO: refactor and simplify this [TAG: KV_API] + virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; // different KV caches require different batch splitting strategies @@ -48,11 +54,10 @@ struct llama_kv_cache : public llama_memory_i { // find an empty slot of size "n_tokens" in the cache virtual bool find_slot(const llama_ubatch & batch) = 0; + // ============================================================================================================= + // getters - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache - virtual llama_pos get_pos_max() const = 0; - virtual bool get_can_shift() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } @@ -87,38 +92,25 @@ private: // llama_kv_cache_unified // -// TODO: add notion of max sequences class llama_kv_cache_unified : public llama_kv_cache { public: - struct kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const kv_cell & other) const { - return seq_id == other.seq_id; - } - }; - static uint32_t get_padding(const llama_cparams & cparams); + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_unified( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - bool offload, - uint32_t kv_size, - uint32_t padding); + const 
llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type); ~llama_kv_cache_unified() = default; @@ -130,10 +122,11 @@ public: bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; // @@ -150,7 +143,6 @@ public: void set_full() override; llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; // updates the cache head @@ -158,50 +150,94 @@ public: // to the first cell of the slot. bool find_slot(const llama_ubatch & batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) - uint32_t size = 0; // total number of cells, shared across all sequences - uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + // + // llama_kv_cache_unified specific API + // - // computed before each graph build - uint32_t n = 0; + uint32_t get_n() const; + uint32_t get_size() const; - std::vector cells; + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; - std::vector k_l; // per layer - std::vector v_l; + // store k_cur and v_cur in the cache based on the current head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax); + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: const llama_model & model; const llama_hparams & hparams; - bool has_shift = false; - bool do_defrag = false; + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + ggml_tensor * k; + ggml_tensor * v; + }; + + bool do_defrag = false; bool v_trans = true; // the value tensor is transposed - bool can_shift = false; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + + // computed before each graph build + // TODO: cells should start to maintain this value dynamically based on the edits + uint32_t n = 0; + + const uint32_t n_seq_max = 1; // required padding - uint32_t padding = 1; + const uint32_t n_pad = 1; - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; + // SWA + const uint32_t n_swa = 0; + + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; std::vector bufs; + llama_kv_cells_unified cells; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + // recovery information used to restore the KV cells to their original state in case of a failure + // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation + // to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API] + struct { + void clear() { + states.clear(); + } + + struct state { + uint32_t i; + + llama_kv_cells_unified cells; + }; + + // stack with the partial states before each ubatch + std::vector states; + } recovery; + // defrag struct { std::vector ids; @@ -210,25 +246,13 @@ private: // return true if cells have been moved bool defrag_prepare(int32_t n_max_nodes); - // commit/restore cache - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; - - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; - - // find how many cells are currently in use - uint32_t cell_max() const; - size_t total_size() const; size_t size_k_bytes() const; size_t size_v_bytes() const; + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, @@ -255,6 +279,100 @@ private: bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is 
for the SWA layers +// upon successful commit, the SWA cache removes old tokens outside the n_swa window + +class llama_kv_cache_unified_iswa : public llama_kv_cache { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_batch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & ctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + + bool find_slot(const llama_ubatch & batch) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + // + + llama_kv_cache_unified * get_kv_base() const; + llama_kv_cache_unified * get_kv_swa () const; + +private: + const llama_hparams & hparams; + + bool do_prune = true; + + struct { + struct entry { + llama_pos pmin; + llama_pos pmax; + }; + + void clear() { + pos.clear(); + } + + // used to perform SWA pruning of old tokens + std::unordered_map pos; + } pending; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + // // llama_kv_cache_recurrent // @@ -286,7 +404,8 @@ public: ggml_type type_k, ggml_type type_v, bool offload, - uint32_t kv_size); + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; @@ -298,10 +417,11 @@ public: bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; // @@ -311,24 +431,17 @@ public: void restore() override; void commit() override; - bool update(llama_context & lctx) override; + bool update(llama_context & ctx) override; void defrag_sched(float thold) override; void set_full() override; llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; bool find_slot(const llama_ubatch & 
batch) override; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; - bool get_can_shift() const override; // TODO: temporary methods - they are not really const as they do const_cast<>, fix this @@ -368,8 +481,7 @@ private: std::vector ranges; } pending; - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; + const uint32_t n_seq_max = 1; std::vector ctxs; std::vector bufs; @@ -388,12 +500,3 @@ private: bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; - - -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max); - -void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv); diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h new file mode 100644 index 00000000..dbbd03fc --- /dev/null +++ b/examples/talk-llama/llama-kv-cells.h @@ -0,0 +1,379 @@ +#pragma once + +#include "llama.h" +#include "llama-cparams.h" + +#include +#include +#include +#include + +// meta information about KV cells that can be part of multiple sequences at the same time +// TODO: add unit tests +class llama_kv_cells_unified { +public: + void reset() { + for (uint32_t i = 0; i < pos.size(); ++i) { + pos[i] = -1; + shift[i] = 0; + seq[i].reset(); + } + + has_shift = false; + + used.clear(); + + for (uint32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + seq_pos[s].clear(); + } + } + + void reset_shift() { + has_shift = false; + + for (uint32_t i = 0; i < shift.size(); ++i) { + shift[i] = 0; + } + } + + uint32_t size() const { + return pos.size(); + } + + void resize(uint32_t n) { + pos.resize(n); + shift.resize(n); + seq.resize(n); + + reset(); + } + + bool is_empty(uint32_t i) const { + assert(i < pos.size()); + assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); + + return pos[i] == -1; + } + + uint32_t get_used() const { + return used.size(); + } + + // the index of the first cell that is used + // return 0 if no cells are used + uint32_t used_min() const { + return used.empty() ? 0 : *used.begin(); + } + + // the index of the last cell that is used + 1 + // return 0 if no cells are used + uint32_t used_max_p1() const { +#if 0 + if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin()); + if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin()); + if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin()); +#endif + + return used.empty() ? 
0 : *used.rbegin() + 1; + } + + bool get_has_shift() const { + return has_shift; + } + + // move cell isrc to idst (used during defrag) + void mv(uint32_t isrc, uint32_t idst) { + assert(isrc < pos.size()); + assert(idst < pos.size()); + + pos [idst] = pos [isrc]; + shift[idst] = shift[isrc]; + seq [idst] = seq [isrc]; + + pos [isrc] = -1; + shift[isrc] = 0; + seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); + } + + // copy the state of cells [i, i + n) (used for save/restore the state of the cells) + llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + assert(i + n <= pos.size()); + + llama_kv_cells_unified res; + + res.resize(n); + + for (uint32_t j = 0; j < n; ++j) { + res.pos[j] = pos[i + j]; + res.seq[j] = seq[i + j]; + + assert(shift[i + j] == 0); + } + + return res; + } + + // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) + void set(uint32_t i, const llama_kv_cells_unified & other) { + assert(i + other.pos.size() <= pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + if (pos[i + j] == -1 && other.pos[j] != -1) { + used.insert(i + j); + } + + if (pos[i + j] != -1 && other.pos[j] == -1) { + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); + } + + pos[i + j] = other.pos[j]; + seq[i + j] = other.seq[j]; + + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + + assert(shift[i + j] == 0); + } + } + + // note: call only if the cell has seq_id + // return true if the cell becomes empty + bool seq_rm(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(seq[i].test(seq_id)); + assert(pos[i] != -1); + assert(seq_id >= 0); + + seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); + + if (seq[i].none()) { + pos[i] = -1; + + used.erase(i); + + return true; + } + + return false; + } + + // return true if the cell becomes empty (i.e. 
it did not contain seq_id before the call) + bool seq_keep(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + + if (seq[i].test(seq_id)) { + seq_pos_rm(i); + seq[i].reset(); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + + return false; + } + + if (seq[i].any()) { + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + + used.erase(i); + + return true; + } + + assert(pos[i] == -1); + + return false; + } + + bool seq_has(uint32_t i, llama_seq_id seq_id) const { + assert(i < pos.size()); + assert(seq_id >= 0); + + return seq[i].test(seq_id); + } + + // note: call only if the cell is not empty and the seq_id is not in the cell + void seq_add(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(pos[i] != -1); + assert(!seq[i].test(seq_id)); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_PARALLEL_SEQUENCES); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); + } + + // note: call only if the cell is not empty + llama_pos pos_get(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return pos[i]; + } + + // note: call only if the cell is not empty + llama_pos get_shift(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return shift[i]; + } + + // check if a cell is not empty and its position is within [p0, p1) + bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { + assert(i < pos.size()); + + return pos[i] >= p0 && pos[i] < p1; + } + + // set the position of an empty cell + // does not modify "has_shift" + // note: call only if the cell is empty + void pos_set(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] == -1); + + pos[i] = p; + + used.insert(i); + } + + // pos[i] = pos[i] + d + // sets "has_shift" to true + // note: call only if the cell is not empty + bool pos_add(uint32_t i, llama_pos d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] += d; + shift[i] += d; + + seq_pos_add(i); + + has_shift = true; + + if (pos[i] < 0) { + seq_pos_rm(i); + + seq[i].reset(); + pos[i] = -1; + + used.erase(i); + + return true; + } + + return false; + } + + // pos[i] = pos[i] / d + // sets "has_shift" to true + // note: call only if the cell is not empty + void pos_div(uint32_t i, int d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + const llama_pos p_old = pos[i]; + + seq_pos_rm(i); + + pos[i] /= d; + shift[i] += p_old - pos[i]; + + seq_pos_add(i); + + has_shift = true; + } + +private: + bool has_shift = false; + + // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) + std::set used; + + std::vector pos; + + // this array accumulates any applied shifts to the pos array since the last reset_shift() call + // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: + // + // cells.pos_add(x, shift_x); + // cells.pos_div(y, shift_y); + // ... 
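// [Editor's sketch - not part of the patch] Why seq_pos_min()/seq_pos_max() above are cheap:
// each sequence keeps an ordered std::set of the positions it currently occupies, so the
// min/max are just begin()/rbegin(). The same idea in isolation:

#include <cstdint>
#include <set>

struct seq_positions {
    std::set<int32_t> pos; // ordered -> min/max queries are trivial

    void    add(int32_t p)    { pos.insert(p); }
    void    remove(int32_t p) { pos.erase(p); }
    int32_t min() const       { return pos.empty() ? -1 : *pos.begin();  }
    int32_t max() const       { return pos.empty() ? -1 : *pos.rbegin(); }
};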
+ // + // if (cells.has_shift()) { + // for (int i = 0; i < n; ++i) { + // auto shift_i = cells.get_shift(i); + // ... + // } + // cells.reset_shift(); + // } + // + std::vector shift; + + using bits_t = std::bitset; + + // the bitset seq[i] tells us which sequences are currently occupying the i-th cell + std::vector seq; + + // the set seq_pos[s] tells us which positions are currently present for sequence s + // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache + std::set seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES]; + + // helper functions for updating `seq_pos`, once cell at a time: + + // remove cell i + void seq_pos_rm(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].erase(pos[i]); + } + } + } + + // add cell i + void seq_pos_add(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) { + if (seq[i].test(s)) { + seq_pos[s].insert(pos[i]); + } + } + } +}; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index c7412d59..a2d25043 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -7,8 +7,8 @@ struct llama_memory_params { ggml_type type_k; ggml_type type_v; - // parameters for other types of memory - // ... + // use full-size SWA cache + bool swa_full; }; // general concept of LLM memory @@ -22,9 +22,10 @@ public: virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; virtual void seq_keep(llama_seq_id seq_id) = 0; - virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; virtual bool get_can_edit() const = 0; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 7fd094b6..e99f5309 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { GGML_ASSERT(hparams.n_expert_used == 0); } - // zero-out the array hparams std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); + + std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -571,9 +574,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full - hparams.n_attn_chunk = 8192; // should this be a gguf kv? 
currently it's the same for Scout and Maverick - hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; @@ -852,22 +856,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } - // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 - if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) { - // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct - hparams.n_swa = 2047; - } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) { - // default value for Phi-3-mini-128k-instruct - // note: this seems incorrect because the window is bigger than the train context? - hparams.n_swa = 262144; - } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) { - // default value for Phi-3-medium-128k-instruct - // note: this seems incorrect because the window is equal to the train context? - hparams.n_swa = 131072; - } - bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (!found_swa && hparams.n_swa == 0) { - throw std::runtime_error("invalid value for sliding_window"); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); } } break; case LLM_ARCH_PHIMOE: @@ -937,8 +936,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA2: { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.n_swa_pattern = 2; + hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -955,7 +955,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { - hparams.n_swa_pattern = 6; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; @@ -1039,7 +1040,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_COHERE2: { - hparams.n_swa_pattern = 4; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -2487,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = 
create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4321,7 +4327,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -4489,7 +4495,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } -ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const { +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} + +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} + +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + // choose long/short freq factors based on the context size if (layers[il].rope_freqs != nullptr) { return layers[il].rope_freqs; @@ -4517,22 +4533,13 @@ struct llm_build_llama : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - // temperature tuning - ggml_tensor * inp_attn_scale = nullptr; - if (arch == LLM_ARCH_LLAMA4) { - inp_attn_scale = build_inp_attn_scale(); - } - auto * inp_attn = build_attn_inp_kv_unified(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - bool use_rope = arch == LLM_ARCH_LLAMA4 - ? 
(il + 1) % hparams.n_no_rope_layer_step != 0 - : true; - // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -4542,7 +4549,169 @@ struct llm_build_llama : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + 
ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_llama_iswa : public llm_graph_context { + llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + inp_attn_scale = build_inp_attn_scale(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4590,7 +4759,7 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + if (use_rope && hparams.use_kq_norm) { // Llama4TextL2Norm Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); @@ -4616,7 +4785,6 @@ struct llm_build_llama : public llm_graph_context { // feed-forward network (non-MoE) if (model.layers[il].ffn_gate_inp == nullptr) { - cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); @@ -4629,9 +4797,7 @@ struct llm_build_llama : public llm_graph_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); - - } else if (arch == LLM_ARCH_LLAMA4) { - // llama4 MoE + } else { ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); @@ -4660,26 +4826,6 @@ struct llm_build_llama : public llm_graph_context { cur = ggml_add(ctx0, moe_out, shexp_out); cb(cur, "ffn_moe_out_merged", il); - - } else { - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -4753,7 +4899,7 @@ struct llm_build_deci : public llm_graph_context { } else if (n_head > 0) { // self-attention // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -7202,6 +7348,7 @@ struct llm_build_phi2 : public llm_graph_context { } }; +template struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, 
const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -7217,7 +7364,14 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
 
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -7225,7 +7379,7 @@
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
@@ -7977,7 +8131,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
             // norm
             cur = build_norm(inpL,
@@ -8277,8 +8431,8 @@ struct llm_build_gemma : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma2 : public llm_graph_context {
-    llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma2_iswa : public llm_graph_context {
+    llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8292,7 +8446,7 @@
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8414,8 +8568,8 @@
     }
 };
 
-struct llm_build_gemma3 : public llm_graph_context {
-    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma3_iswa : public llm_graph_context {
+    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8433,13 +8587,11 @@
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_swa = hparams.is_swa(il);
-
-            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
-            const float freq_scale_l = is_swa ?
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -9016,8 +9168,8 @@ struct llm_build_command_r : public llm_graph_context { } }; -struct llm_build_cohere2 : public llm_graph_context { - llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_cohere2_iswa : public llm_graph_context { + llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9032,7 +9184,7 @@ struct llm_build_cohere2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); for (int il = 0; il < n_layer; ++il) { const bool is_swa = hparams.is_swa(il); @@ -9045,7 +9197,7 @@ struct llm_build_cohere2 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -9983,7 +10135,7 @@ struct llm_build_deepseek : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11347,7 +11499,7 @@ struct llm_build_exaone : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12263,7 +12415,7 @@ struct llm_build_granite : public llm_graph_context { Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); if (use_rope) { - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -12916,7 +13068,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -13044,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_WAVTOKENIZER_DEC: { res = nullptr; } break; @@ -13058,7 +13211,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params 
& params,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max));
+                        std::max((uint32_t) 1, cparams.n_seq_max),
+                        cparams.n_seq_max);
             } break;
         default:
             {
@@ -13068,14 +13222,36 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    GGML_ASSERT(hparams.is_swa_any());
+
+                    res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            params.swa_full,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                } else {
+                    GGML_ASSERT(!hparams.is_swa_any());
+
+                    res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            cparams.n_seq_max,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                }
             }
     }
 
@@ -13090,11 +13266,14 @@ llm_graph_result_ptr llama_model::build_graph(
 
     switch (arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
            {
                llm = std::make_unique<llm_build_llama>(*this, params, gf);
            } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
+            } break;
        case LLM_ARCH_DECI:
            {
                llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13169,7 +13348,11 @@ llm_graph_result_ptr llama_model::build_graph(
        case LLM_ARCH_PHI3:
        case LLM_ARCH_PHIMOE:
            {
-                llm = std::make_unique<llm_build_phi3>(*this, params, gf);
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
+                } else {
+                    llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
+                }
            } break;
        case LLM_ARCH_PLAMO:
            {
@@ -13201,11 +13384,11 @@ llm_graph_result_ptr llama_model::build_graph(
            } break;
        case LLM_ARCH_GEMMA2:
            {
-                llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
            } break;
        case LLM_ARCH_GEMMA3:
            {
-                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
            } break;
        case LLM_ARCH_STARCODER2:
            {
@@ -13225,7 +13408,7 @@ llm_graph_result_ptr llama_model::build_graph(
            } break;
        case LLM_ARCH_COHERE2:
            {
-                llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
+                llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
            } break;
        case LLM_ARCH_DBRX:
            {
diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h
index 6bdec263..cbea2cb3 100644
--- a/examples/talk-llama/llama-model.h
+++ b/examples/talk-llama/llama-model.h
@@ -398,7 +398,10 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
-    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+    float get_rope_freq_base (const llama_cparams & cparams, int il) const;
+    float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
+
+    ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp
index 804b11e0..bfbf5fa2 100644
--- a/examples/talk-llama/llama-sampling.cpp
+++ b/examples/talk-llama/llama-sampling.cpp
@@ -798,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d
     }
 
     // if we have enough values the operation was a success
-    if (filtered_tokens.size() >= ctx->min_keep) {
+    if (!filtered_tokens.empty() &&
filtered_tokens.size() >= ctx->min_keep) { memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); cur_p->size = filtered_tokens.size(); min_p_applied = true; @@ -909,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { last_idx = i + 1; break; } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 9389ca80..d5a036a8 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); // at the beginning tokenization score is zero tokenization_results[0] = { vocab.token_unk(), 0, 0 }; @@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session { const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; current_champ = challenger; } } @@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session { prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; current_champ = challenger; } } @@ -1007,7 +1007,7 @@ private: struct best_tokenization { llama_token token_id; size_t input_offset; - float score_sum; + double score_sum; }; struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 99e5fba2..01762bea 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -361,10 +361,11 @@ extern "C" { // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
bool embeddings; // if true, extract embeddings (together with logits) - bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool no_perf; // whether to measure performance timings - bool op_offload; // whether to offload host tensor operations to device + bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // use flash attention [EXPERIMENTAL] + bool no_perf; // measure performance timings + bool op_offload; // offload host tensor operations to device + bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) }; // model quantization parameters @@ -470,6 +471,7 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); + LLAMA_API size_t llama_max_parallel_sequences(void); LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); @@ -607,71 +609,14 @@ extern "C" { // KV cache // - // TODO: start using struct llama_kv_cache - - // Information associated with an individual cell in the KV cache view. - struct llama_kv_cache_view_cell { - // The position for this cell. Takes KV cache shifts into account. - // May be negative if the cell is not populated. - llama_pos pos; - }; - - // An updateable view of the KV cache. - struct llama_kv_cache_view { - // Number of KV cache cells. This will be the same as the context size. - int32_t n_cells; - - // Maximum number of sequences that can exist in a cell. It's not an error - // if there are more sequences in a cell than this value, however they will - // not be visible in the view cells_sequences. - int32_t n_seq_max; - - // Number of tokens in the cache. For example, if there are two populated - // cells, the first with 1 sequence id in it and the second with 2 sequence - // ids then you'll have 3 tokens. - int32_t token_count; - - // Number of populated cache cells. - int32_t used_cells; - - // Maximum contiguous empty slots in the cache. - int32_t max_contiguous; - - // Index to the start of the max_contiguous slot range. Can be negative - // when cache is full. - int32_t max_contiguous_idx; - - // Information for an individual cell. - struct llama_kv_cache_view_cell * cells; - - // The sequences for each cell. There will be n_seq_max items per cell. - llama_seq_id * cells_sequences; - }; - - // Create an empty KV cache view. (use only for debugging purposes) - LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max); - - // Free a KV cache view. (use only for debugging purposes) - LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); - - // Update the KV cache view structure with the current state of the KV cache. 
(use only for debugging purposes) - // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx) - LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); - - /// - // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times - LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx); - - DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx), - "use llama_kv_self_n_tokens instead"); + DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) - LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx); - - DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx), - "use llama_kv_self_used_cells instead"); + DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_self_clear( @@ -730,10 +675,18 @@ extern "C" { llama_pos p1, int d); + // Returns the smallest position present in the KV cache for the specified sequence + // This is typically non-zero only for SWA caches + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_kv_self_seq_pos_min( + struct llama_context * ctx, + llama_seq_id seq_id); + // Returns the largest position present in the KV cache for the specified sequence + // Return -1 if the sequence is empty LLAMA_API llama_pos llama_kv_self_seq_pos_max( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id); // Defragment the KV cache // This will be applied: @@ -747,61 +700,6 @@ extern "C" { // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 
LLAMA_API void llama_kv_self_update(struct llama_context * ctx); - DEPRECATED(LLAMA_API void llama_kv_cache_clear( - struct llama_context * ctx), - "use llama_kv_self_clear instead"); - - DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_rm instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_cp instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_kv_self_seq_keep instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta), - "use llama_kv_self_seq_add instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d), - "use llama_kv_self_seq_div instead"); - - DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_kv_self_seq_pos_max instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx), - "use llama_kv_self_defrag instead"); - - DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx), - "use llama_kv_self_can_shift instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx), - "use llama_kv_self_update instead"); - - // // State / sessions // @@ -943,9 +841,12 @@ extern "C" { // Requires KV cache. // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. - // 0 - success - // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error. the KV cache state is restored to the state before this call + // Upon non-zero return values, the KV cache state is restored to the state before this call + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // 2 - aborted + // -1 - invalid input batch + // < -1 - error LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch);
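
A reduced, self-contained model of the per-sequence position bookkeeping that the llama_kv_cells changes above rely on: each sequence keeps an ordered set of the positions it currently occupies, so the new seq_pos_min()/seq_pos_max() queries reduce to *begin()/*rbegin(). The names mirror the diff, but the small constant below is a placeholder, not the real LLAMA_MAX_PARALLEL_SEQUENCES, and this is a sketch rather than the actual implementation.

#include <cstdint>
#include <set>

struct seq_pos_index {
    static const int MAX_SEQ = 4; // placeholder for LLAMA_MAX_PARALLEL_SEQUENCES

    // seq_pos[s] holds every position currently occupied by sequence s
    std::set<int32_t> seq_pos[MAX_SEQ];

    void add(int s, int32_t p) { seq_pos[s].insert(p); }
    void rm (int s, int32_t p) { seq_pos[s].erase (p); }

    // -1 means "sequence not present", matching the seq_pos_min/seq_pos_max contract in the diff
    int32_t pos_min(int s) const { return seq_pos[s].empty() ? -1 : *seq_pos[s].begin();  }
    int32_t pos_max(int s) const { return seq_pos[s].empty() ? -1 : *seq_pos[s].rbegin(); }
};

Because insertion and erasure of a position are O(log n) and the min/max lookups touch only the first and last element, the cache can answer per-sequence range queries without scanning all cells.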
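
The hparams changes replace the old n_swa_pattern field with set_swa_pattern(n), described in the diff as "pattern: 3 chunked - 1 full" for set_swa_pattern(4). The helper below is a hedged illustration of that layer layout (every n-th layer uses full attention, the rest use the sliding/chunked window); it is an assumption about the pattern semantics, not the actual llama_hparams code.

#include <cstdint>

// illustrative only: returns true if layer il would use SWA under an n_pattern-periodic layout
static bool layer_is_swa(uint32_t il, uint32_t n_pattern) {
    if (n_pattern == 0) {
        return false; // no SWA layers at all
    }
    // e.g. n_pattern == 4 -> layers 0,1,2 are SWA, layer 3 is full attention, then repeat
    return il % n_pattern < (n_pattern - 1);
}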
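
A standalone illustration of the compile-time dispatch pattern used by the templated llm_build_phi3<iswa> builder above: std::conditional_t selects the attention-input type and `if constexpr` selects the matching factory, so the unused branch is compiled out. The two input types and factory functions here are stand-ins, not the real llama.cpp types.

#include <type_traits>

struct attn_inp      { /* stand-in for the regular unified KV-cache input */ };
struct attn_inp_iswa { /* stand-in for the SWA-aware unified KV-cache input */ };

static attn_inp *      make_inp()      { static attn_inp      x; return &x; }
static attn_inp_iswa * make_inp_iswa() { static attn_inp_iswa x; return &x; }

template <bool iswa>
struct graph_builder {
    using inp_attn_type = std::conditional_t<iswa, attn_inp_iswa, attn_inp>;

    inp_attn_type * inp_attn = nullptr;

    graph_builder() {
        if constexpr (iswa) {
            inp_attn = make_inp_iswa(); // SWA build path
        } else {
            inp_attn = make_inp();      // regular build path
        }
    }
};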
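
The extra "min_keep == 0" guard added to the samplers matters because min_keep is unsigned in the surrounding code (size_t, as the diff context suggests), so "min_keep - 1" wraps around when min_keep is zero and the old condition could never hold. A small self-contained demonstration of the wraparound (values are arbitrary):

#include <cstddef>
#include <cstdio>

int main() {
    const size_t min_keep = 0;
    const size_t i        = 3;

    const bool old_cond = (i >= min_keep - 1);                  // min_keep - 1 wraps to SIZE_MAX -> false
    const bool new_cond = (min_keep == 0 || i >= min_keep - 1); // guarded -> true

    std::printf("old: %d, new: %d\n", old_cond, new_cond);
    return 0;
}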
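
A minimal sketch of opting out of the full-size SWA cache from client code via the new swa_full flag in llama_context_params, assuming the usual llama_context_default_params()/llama_init_from_model() flow; the chosen values and the trade-off comment are illustrative, not prescribed by the patch.

#include "llama.h"

static llama_context * make_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_seq_max = 4;      // must not exceed llama_max_parallel_sequences()
    cparams.swa_full  = false;  // smaller SWA cache; cache reuse across calls becomes more limited

    return llama_init_from_model(model, cparams);
}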
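
With llama_kv_self_n_tokens() and llama_kv_self_used_cells() deprecated, a client can derive per-sequence occupancy from the llama_kv_self_seq_pos_min()/llama_kv_self_seq_pos_max() pair declared above. A sketch of that replacement; the helper name is an assumption, and note that for a SWA cache the minimum is typically non-zero, so the result is a position span rather than a cell count.

static int32_t kv_positions_in_seq(llama_context * ctx, llama_seq_id seq_id) {
    const llama_pos p0 = llama_kv_self_seq_pos_min(ctx, seq_id);
    const llama_pos p1 = llama_kv_self_seq_pos_max(ctx, seq_id);

    if (p0 < 0) {
        return 0; // the sequence is empty (both calls return -1)
    }

    return (p1 - p0) + 1;
}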
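
Finally, a caller-side sketch reacting to the llama_decode() return codes documented at the end of the header. Only the meaning of the codes (0, 1, 2, -1, < -1) and the guarantee that the KV cache state is restored on non-zero returns come from the diff; the retry policy hinted at in the comments is an assumption for illustration.

static int32_t decode_checked(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);

    switch (ret) {
        case 0:
            return 0;    // success
        case 1:
            // no KV slot found - cache state was restored, so the caller may retry
            // with a smaller batch or create the context with a larger n_ctx
            return 1;
        case 2:
            return 2;    // aborted (e.g. by the abort callback)
        case -1:
            return -1;   // invalid input batch
        default:
            return ret;  // < -1: fatal error
    }
}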