mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-22 16:38:58 +00:00
talk-llama : sync llama.cpp
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, x86, 0.3.29, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, x64_64, 0.3.29, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.4.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-intel.Dockerfile platform:linux/amd64 tag:main-intel]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, x86, 0.3.29, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, x64_64, 0.3.29, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.4.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-intel.Dockerfile platform:linux/amd64 tag:main-intel]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
ggml-ci
This commit is contained in:
@ -127,6 +127,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
|
||||
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::clear(bool data) {
|
||||
@ -307,24 +310,27 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
||||
const llama_batch & batch,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled,
|
||||
bool logits_all) {
|
||||
GGML_UNUSED(embd_pooled);
|
||||
bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (sbatch.n_tokens > 0) {
|
||||
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
||||
}
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (sbatch.n_tokens > 0) {
|
||||
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
||||
}
|
||||
|
||||
auto heads = prepare(ubatches);
|
||||
if (heads.empty()) {
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
auto heads = prepare(ubatches);
|
||||
if (heads.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(
|
||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||
return std::make_unique<llama_kv_cache_unified_state>(
|
||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
||||
@ -512,43 +518,68 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
head_cur = 0;
|
||||
}
|
||||
|
||||
// otherwise, one cell per token.
|
||||
|
||||
if (n_tokens > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
||||
return -1;
|
||||
}
|
||||
|
||||
//#define FIND_SLOT_DEBUG 1
|
||||
#if FIND_SLOT_DEBUG
|
||||
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", cells.used_max_p1(), cells.get_used(), head, n_swa);
|
||||
if (debug > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
|
||||
|
||||
// for debugging
|
||||
{
|
||||
std::string ss;
|
||||
if (n_swa > 0) {
|
||||
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
||||
std::string ss;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.is_empty(i)) {
|
||||
ss += '.';
|
||||
} else {
|
||||
ss += std::to_string(cells.seq_get(i));
|
||||
assert(cells.seq_count(i) >= 1);
|
||||
|
||||
if (cells.seq_count(i) == 1) {
|
||||
ss += std::to_string(cells.seq_get(i));
|
||||
} else {
|
||||
ss += 'M';
|
||||
}
|
||||
}
|
||||
if (i%256 == 255) {
|
||||
ss += " *";
|
||||
ss += '\n';
|
||||
}
|
||||
}
|
||||
}
|
||||
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
|
||||
}
|
||||
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
if (cells.seq_pos_min(s) < 0) {
|
||||
continue;
|
||||
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("kv_cells: n_swa = %4d, min[%d] = %5d, max[%d] = %5d\n", n_swa, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
|
||||
if ((debug == 2 && n_swa > 0) || debug > 2) {
|
||||
std::string ss;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
std::string cur;
|
||||
if (cells.is_empty(i)) {
|
||||
cur = '.';
|
||||
} else {
|
||||
cur = std::to_string(cells.pos_get(i));
|
||||
}
|
||||
const int n = cur.size();
|
||||
for (int j = 0; j < 5 - n; ++j) {
|
||||
cur += ' ';
|
||||
}
|
||||
ss += cur;
|
||||
if (i%256 == 255) {
|
||||
ss += " *";
|
||||
}
|
||||
if (i%64 == 63) {
|
||||
ss += '\n';
|
||||
}
|
||||
}
|
||||
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
|
||||
}
|
||||
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (cells.seq_pos_min(s) < 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
uint32_t n_tested = 0;
|
||||
|
||||
@ -559,21 +590,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
continue;
|
||||
}
|
||||
|
||||
// keep track of what the minimum sequence positions would be if we accept the ubatch
|
||||
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
|
||||
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
|
||||
seq_pos_min[s] = cells.seq_pos_min(s);
|
||||
}
|
||||
|
||||
bool found = true;
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_pos pos = ubatch.pos[i];
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
//const llama_pos pos = ubatch.pos[i];
|
||||
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
// can we use this cell? either:
|
||||
// - the cell is empty
|
||||
// - the cell is occupied only by one sequence:
|
||||
// - mask causally, if the sequence is the same as the one we are inserting
|
||||
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
||||
// - mask SWA, using current max pos for that sequence in the cache
|
||||
// always insert in the cell with minimum pos
|
||||
bool can_use = cells.is_empty(head_cur + i);
|
||||
@ -581,21 +606,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
if (!can_use && cells.seq_count(head_cur + i) == 1) {
|
||||
const llama_pos pos_cell = cells.pos_get(head_cur + i);
|
||||
|
||||
// causal mask
|
||||
if (cells.seq_has(head_cur + i, seq_id)) {
|
||||
can_use = pos_cell >= pos;
|
||||
}
|
||||
// (disabled) causal mask
|
||||
// note: it's better to purge any "future" tokens beforehand
|
||||
//if (cells.seq_has(head_cur + i, seq_id)) {
|
||||
// can_use = pos_cell >= pos;
|
||||
//}
|
||||
|
||||
if (!can_use) {
|
||||
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
|
||||
|
||||
// SWA mask
|
||||
// note: we insert only in the cell with minimum pos in order to preserve the invariant that
|
||||
// all positions between [pos_min, pos_max] for each sequence will be present in the cache
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
||||
if (pos_cell == seq_pos_min[seq_id_cell] &&
|
||||
is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
seq_pos_min[seq_id_cell]++;
|
||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
can_use = true;
|
||||
}
|
||||
}
|
||||
@ -623,18 +644,58 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (!cells.is_empty(head_cur + i)) {
|
||||
cells.rm(head_cur + i);
|
||||
}
|
||||
if (debug > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
|
||||
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
|
||||
}
|
||||
|
||||
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
seq_pos_max_rm[s] = -1;
|
||||
}
|
||||
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
||||
cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
|
||||
const uint32_t idx = s*ubatch.n_seq_tokens + j;
|
||||
|
||||
if (!cells.is_empty(head_cur + idx)) {
|
||||
assert(cells.seq_count(head_cur + idx) == 1);
|
||||
|
||||
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
|
||||
const llama_pos pos = cells.pos_get(head_cur + idx);
|
||||
|
||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||
|
||||
cells.rm(head_cur + idx);
|
||||
}
|
||||
|
||||
cells.pos_set(head_cur + idx, ubatch.pos[idx]);
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
|
||||
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
||||
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq_pos_max_rm[s] == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
|
||||
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
|
||||
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
|
||||
|
||||
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
||||
}
|
||||
}
|
||||
// move the head at the end of the slot
|
||||
head = head_cur + ubatch.n_tokens;
|
||||
}
|
||||
@ -731,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const uint32_t n_seqs = ubatch->n_seqs;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
float * data = (float *) dst->data;
|
||||
|
||||
const auto n_kv = dst->ne[0];
|
||||
const int64_t n_kv = dst->ne[0];
|
||||
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
@ -752,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
||||
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
|
||||
const uint32_t idx = s*n_seq_tokens + j;
|
||||
|
||||
const llama_pos p1 = ubatch->pos[idx];
|
||||
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
float f = 0.0f;
|
||||
@ -787,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
||||
f = -INFINITY;
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1447,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
seq_rm(dest_seq_id, -1, -1);
|
||||
|
||||
llama_sbatch sbatch;
|
||||
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
|
||||
batch.n_tokens = cell_count;
|
||||
ubatch.n_tokens = cell_count;
|
||||
ubatch.n_seq_tokens = cell_count;
|
||||
ubatch.n_seqs = 1;
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
@ -1469,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
batch.n_seq_id[i] = n_seq_id;
|
||||
batch.seq_id[i] = &dest_seq_id;
|
||||
ubatch.pos[i] = pos;
|
||||
ubatch.n_seq_id[i] = n_seq_id;
|
||||
ubatch.seq_id[i] = &dest_seq_id;
|
||||
}
|
||||
|
||||
const auto head_cur = find_slot(batch);
|
||||
const auto head_cur = find_slot(ubatch);
|
||||
if (head_cur < 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
apply_ubatch(head_cur, batch);
|
||||
apply_ubatch(head_cur, ubatch);
|
||||
|
||||
// keep the head at the old position because we will read the KV data into it in state_read_data()
|
||||
head = head_cur;
|
||||
@ -1488,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
// DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head_cur + cell_count <= cells.size());
|
||||
GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
|
||||
GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
|
||||
} else {
|
||||
@ -1674,7 +1739,7 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
||||
if (!do_shift && dinfo.empty()) {
|
||||
if (!do_shift && this->dinfo.empty()) {
|
||||
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user