talk-llama : sync llama.cpp

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-10 10:12:44 +03:00
parent 96eaf46ec6
commit db264d6220
23 changed files with 911 additions and 437 deletions

View File

@ -2,8 +2,8 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map>
#include <vector>
@ -17,13 +17,26 @@ struct llama_context;
// llama_kv_cache_unified
//
class llama_kv_cache_unified : public llama_kv_cache {
class llama_kv_cache_unified : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
using ubatch_heads = std::vector<uint32_t>;
struct defrag_info {
bool empty() const {
return ids.empty();
}
// contains information about which cell moves where:
// - cell i moves to ids[i]
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
std::vector<uint32_t> ids;
};
llama_kv_cache_unified(
const llama_model & model,
layer_filter_cb && filter,
@ -43,7 +56,19 @@ public:
// llama_memory_i
//
void clear() override;
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -54,24 +79,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
bool get_can_shift() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@ -83,6 +90,8 @@ public:
uint32_t get_size() const;
bool get_has_shift() const;
//
// graph_build API
//
@ -103,7 +112,9 @@ public:
// find places for the provided ubatches in the cache, returns the head locations
// return empty vector on failure
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
// return the cell position where we can insert the ubatch
// return -1 on failure to find a contiguous slot of kv cells
@ -133,8 +144,7 @@ private:
ggml_tensor * v;
};
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
bool v_trans = true; // the value tensor is transposed
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@ -160,13 +170,8 @@ private:
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
// defrag
struct {
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// return non-empty vector if cells have been moved
defrag_info defrag_prepare(int32_t n_max_nodes) const;
size_t total_size() const;
@ -192,7 +197,8 @@ private:
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
ggml_cgraph * gf,
const defrag_info & dinfo) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@ -203,20 +209,29 @@ private:
class llama_kv_cache_unified_state : public llama_memory_state_i {
public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors
llama_kv_cache_unified_state(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv);
// used to create a state from a batch
// used to create an update state
llama_kv_cache_unified_state(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
// used to create a decode state from a batch
llama_kv_cache_unified_state(
llama_memory_status status,
llama_kv_cache_unified * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads,
ubatch_heads heads,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state();
@ -253,16 +268,30 @@ public:
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
private:
const llama_memory_status status;
llama_memory_status status;
llama_kv_cache_unified * kv;
llama_context * lctx;
//
// update state
//
bool do_shift = false;
defrag_info dinfo;
//
// batch processing state
//
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<uint32_t> heads;
ubatch_heads heads;
std::vector<llama_ubatch> ubatches;
//