Mirror of https://github.com/ggerganov/whisper.cpp.git
Synced 2025-06-22 16:38:58 +00:00
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, x86, 0.3.29, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, x64_64, 0.3.29, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.4.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-intel.Dockerfile platform:linux/amd64 tag:main-intel]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled
ggml-ci
148 lines · 5.2 KiB · C++
#pragma once
|
|
|
|
#include "llama.h"
|
|
|
|
#include "llama-cparams.h"
|
|
|
|
#include <array>
|
|
#include <vector>
|
|
#include <set>
|
|
#include <bitset>
|
|
#include <unordered_map>
|
|
|
|
// keep this struct lightweight
|
|
// it points to data in `llama_batch_allocr`
|
|
struct llama_ubatch {
|
|
bool equal_seqs;
|
|
// TODO: whole_seqs for embeddings?
|
|
|
|
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
|
uint32_t n_seq_tokens; // tokens per sequence set
|
|
uint32_t n_seqs; // sequence sets in the ubatch
|
|
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
|
|
|
|
// seq_id_unq: unique sequence ids in the ubatch
|
|
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
|
|
// used for extracting sequence pooled embeddings
|
|
|
|
// // size | idx | val
|
|
llama_token * token; // [n_tokens] | i | id, token
|
|
float * embd; // [n_embd, n_tokens] | i | embd
|
|
llama_pos * pos; // [n_tokens] | i | pos
|
|
int32_t * n_seq_id; // [n_tokens] | i | -
|
|
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
|
|
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
|
|
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
|
|
int8_t * output; // [n_tokens] | i | -
|
|
};
|
|
|
|
// a helper for sanitizing, fulfilling and splitting a batch
|
|
class llama_batch_allocr {
|
|
public:
|
|
llama_batch_allocr(uint32_t n_pos_per_embd);
|
|
|
|
// sanitize and auto-gen missing data in the input batch
|
|
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
|
|
bool init(
|
|
const llama_batch & batch_inp,
|
|
const llama_vocab & vocab,
|
|
const llama_memory_i * memory,
|
|
uint32_t n_embd,
|
|
bool output_all);
|
|
|
|
const llama_batch & get_batch() const;
|
|
|
|
uint32_t get_n_tokens() const;
|
|
uint32_t get_n_outputs() const;
|
|
|
|
// the array of output indices in the order they were encountered during the ubatch splitting
|
|
std::vector<int32_t> & get_out_ids();
|
|
|
|
// min/max positions of each sequence in the current ubatch
|
|
llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
|
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
|
|
|
// call once before splitting the batch to reset the internal state
|
|
void split_reset();
|
|
|
|
// simple split, unknown number of sequence sets of unequal lengths
|
|
llama_ubatch split_simple(uint32_t n_ubatch);
|
|
|
|
// make ubatches of equal-length sequences sets
|
|
llama_ubatch split_equal(uint32_t n_ubatch);
|
|
|
|
// sequence-set-wise split - each ubatch contains a single sequence-set
|
|
llama_ubatch split_seq(uint32_t n_ubatch);
|
|
|
|
// a helper method for creating a well-defined ubatch of tokens
|
|
// TODO: support embeddings if needed in the future
|
|
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
|
|
|
|
private:
|
|
void clear();
|
|
|
|
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
|
|
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
|
|
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
|
|
|
|
// for debugging, start with LLAMA_BATCH_DEBUG=2
|
|
void ubatch_print(const llama_ubatch & ubatch, int debug);
|
|
|
|
llama_batch batch;
|
|
|
|
// only for debugging purposes
|
|
const llama_vocab * vocab;
|
|
|
|
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
|
|
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
|
const uint32_t n_pos_per_embd;
|
|
|
|
uint32_t n_embd;
|
|
uint32_t n_outputs;
|
|
|
|
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
|
|
|
std::vector<llama_pos> pos;
|
|
std::vector<int32_t> n_seq_id;
|
|
std::vector<llama_seq_id *> seq_id;
|
|
std::vector<llama_seq_id> seq_id_unq;
|
|
std::vector<int32_t> seq_idx;
|
|
std::vector<int8_t> output;
|
|
|
|
using pos_set_t = std::set<llama_pos>;
|
|
using seq_cpl_t = std::vector<bool>;
|
|
|
|
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
|
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
|
|
|
using idx_vec_t = std::vector<int32_t>;
|
|
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
|
|
|
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
|
|
|
|
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
|
|
|
|
// batch indices of the output
|
|
std::vector<int32_t> out_ids;
|
|
|
|
// used[i] indicates if token i has already been used in a previous ubatch
|
|
std::vector<bool> used;
|
|
|
|
// llama_ubatch points to this data:
|
|
struct ubatch {
|
|
std::vector<llama_token> token;
|
|
std::vector<float> embd;
|
|
std::vector<llama_pos> pos;
|
|
std::vector<int32_t> n_seq_id;
|
|
std::vector<llama_seq_id *> seq_id;
|
|
std::vector<llama_seq_id> seq_id_unq;
|
|
std::vector<int32_t> seq_idx;
|
|
std::vector<int8_t> output;
|
|
};
|
|
|
|
// current splitting state:
|
|
std::vector<ubatch> ubatches;
|
|
|
|
int debug;
|
|
};
|