Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-06-16 13:58:09 +00:00)
talk-llama : sync llama.cpp
ggml-ci
examples/talk-llama/llama-kv-cache-unified.cpp

@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
 #include "llama-io.h"
 #include "llama-model.h"
+#include "llama-context.h"
 
@@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
 }
 
-void llama_kv_cache_unified::clear() {
+void llama_kv_cache_unified::clear(bool data) {
     cells.reset();
 
     head = 0;
 
-    for (auto & buf : bufs) {
-        ggml_backend_buffer_clear(buf.get(), 0);
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
     }
 }
 
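Note: `clear(bool data)` separates the cheap metadata reset from zeroing the backend buffers. A minimal caller sketch (the `kv` instance and the call sites are assumed for illustration, not part of this diff):

    kv.clear(false); // forget all cells, but leave the old K/V bytes in the buffers
    kv.clear(true);  // also zero every backend buffer, as the state_read error paths below do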
@@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
-        }
+    if (seq_id >= 0) {
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
 
-        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
-            if (new_head == cells.size()) {
-                new_head = i;
+            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
+        }
+    } else {
+        // match any sequence
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            cells.rm(i);
+
+            if (new_head == cells.size()) {
+                new_head = i;
             }
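Note: `seq_rm` now branches on the sign of `seq_id`: a non-negative id detaches that one sequence from each cell in [p0, p1), while a negative id takes the new "match any sequence" path and removes matching cells outright via `cells.rm(i)`. A hedged usage sketch (assumed caller, illustrative arguments):

    kv.seq_rm(0, 10, 20);  // drop sequence 0 from positions [10, 20)
    kv.seq_rm(-1, 10, 20); // negative seq_id: remove those positions for every sequence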
@@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+}
+
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
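Note: `init_update` folds the old `defrag_sched` heuristic into state creation. Worked numbers for the fragmentation estimate (illustrative values, not from the diff): with `used_max_p1() = 4096`, `get_used() = 3000` and `n_pad = 32`, fragmentation = 1 - (3000 + 32)/4096 ≈ 0.26, so any `defrag_thold` below 0.26 triggers a defrag; caches with fewer than 2048 cells in play never do. A standalone recomputation:

    #include <algorithm> // std::max
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t n_kv   = 4096; // stands in for cells.used_max_p1()
        const uint32_t n_used = 3000; // stands in for cells.get_used()
        const uint32_t n_pad  = 32;
        const float fragmentation =
            n_kv >= 2048 ? std::max(0.0f, 1.0f - float(n_used + n_pad)/n_kv) : 0.0f;
        std::printf("fragmentation: %.2f\n", fragmentation); // prints 0.26
        return 0;
    }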
@@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -401,56 +452,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
-
-            auto * gf = lctx.graph_init();
-
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
-            }
-
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
-
-            res->set_inputs(nullptr);
-
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
-            }
-
-            updated = true;
-        }
-
-        do_defrag = false;
-    }
-
-    return updated;
-}
-
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
-
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
-
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-        do_defrag = true;
-    }
-}
+    if (!dinfo.empty()) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
+
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
+
+                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
+                    continue;
+                }
+
+                cells.mv(i, dinfo.ids[i]);
+            }
+
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
+        }
+
+        ggml_backend_sched_reset(sched);
+
+        auto * gf = lctx->graph_init();
+
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
+
+        res->set_inputs(nullptr);
+
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
+
+        updated = true;
+    }
+
+    return updated;
+}
 
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
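Note: `update` now applies the planned moves to the cell metadata up front, before the backend graph shuffles the actual K/V data. The `ids` encoding (cell `i` moves to `ids[i]`; `ids[i] == i` or `ids[i] == n_kv` means "not moved") in a self-contained sketch with hypothetical values:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical 4-cell plan: slots 0, 2, 3 occupied, slot 1 free;
        // the plan compacts by moving cell 2 -> 1 and cell 3 -> 2
        std::vector<uint32_t> ids = {0, 4, 1, 2};
        const uint32_t n_kv = (uint32_t) ids.size();
        for (uint32_t i = 0; i < n_kv; ++i) {
            if (ids[i] == n_kv || ids[i] == i) {
                continue; // cell i is not moved
            }
            std::printf("mv %u -> %u\n", i, ids[i]); // stands in for cells.mv(i, ids[i])
        }
        return 0;
    }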
@@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
@@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
     ggml_set_input(inp->k_shift);
 
     for (const auto & layer : layers) {
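Note: sizing `k_shift` by `cells.size()` instead of `cparams.n_ctx` unties the shift graph from the configured context length; the two can differ, e.g. when a cache is sized independently of n_ctx (such as a smaller sliding-window cache). An illustration with assumed numbers:

    // illustrative values only:
    // n_ctx   = 8192  (cparams.n_ctx)
    // n_cells = 4096  (cells.size())
    // old: I32 tensor of 8192 entries, half of them never used
    // new: I32 tensor of 4096 entries, exactly one per cache cell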
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf,
+          const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv = cells.used_max_p1();
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    // cell i moves to ids[i]
-    //
-    // if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+    defrag_info res;
+    auto & ids = res.ids;
 
-    ids.clear();
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
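Note: `defrag_prepare` is now `const` and returns the plan as a `defrag_info` value instead of mutating a member; an empty plan (see the `return {}` below) means nothing to do. The `max_moves` context line above caps the plan by graph size, roughly 6 nodes per move per layer plus 2 per layer of overhead. Worked numbers (illustrative assumptions):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t  n_max_nodes = 8192; // assumed value of lctx->graph_max_nodes()
        const uint32_t n_layer     = 32;   // assumed model depth
        const uint32_t max_moves   = (n_max_nodes - 2*n_layer)/(6*n_layer);
        std::printf("max_moves: %u\n", max_moves); // (8192 - 64)/192 = 42
        return 0;
    }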
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 
     if (!res) {
         if (seq_id == -1) {
-            clear();
+            clear(true);
         } else {
             seq_rm(seq_id, -1, -1);
         }
@@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             return false;
         }
 
-        clear();
+        clear(true);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv) : status(status), kv(kv) {
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
     head = 0;
 }
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-        llama_memory_status status,
-        llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
-        std::vector<uint32_t> heads,
-        std::vector<llama_ubatch> ubatches)
-    : status(status),
-      kv(kv),
-      sbatch(std::move(sbatch)),
-      heads(std::move(heads)),
-      ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+    }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
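Note: the success-path constructors no longer take a `llama_memory_status`; each sets it itself, and the new update-state constructor downgrades to `LLAMA_MEMORY_STATUS_NO_UPDATE` when there is neither a shift nor a defrag plan. A hedged caller fragment (assumes the `get_status()` accessor of the memory-state interface):

    auto mstate = kv.init_update(lctx, /*optimize =*/ false);
    if (mstate->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
        // nothing to shift or defragment this time; skip the graph work entirely
    }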
@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
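Note: an update state carries no ubatches, so `apply()` doubles as the entry point for K-shift and defrag: with an empty ubatch list it forwards to `kv->update(lctx, do_shift, dinfo)`, otherwise it places the next ubatch as before. A driver-side fragment under the same assumptions as the previous note:

    if (mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
        mstate->apply(); // empty ubatches -> runs the KV cache update path
    }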