Flash + language support (ref #2)

- Achieved a big performance improvement and reduced memory usage
- Can now translate / transcribe different languages
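For example, once the new multilingual models are downloaded, the options added in this commit can be combined as follows (a rough sketch: the multilingual model file name and the sample file below are illustrative, not part of this commit):

```bash
# download the multilingual "base" model and run it on the bundled samples
make base

# transcribe non-English speech (here German), or translate it to English
./main -m models/ggml-base.bin -l de -f samples/german.wav
./main -m models/ggml-base.bin -l de --translate -f samples/german.wav
```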
Georgi Gerganov 2022-09-28 20:46:05 +03:00
parent 154fa796dd
commit f888c2373d
6 changed files with 1278 additions and 139 deletions

Makefile

@@ -30,11 +30,16 @@ samples:
 # runs it on all samples in the folder "./samples":
 .PHONY: tiny.en
+.PHONY: tiny
 .PHONY: base.en
-.PHONY: medium.en
+.PHONY: base
 .PHONY: small.en
+.PHONY: small
+.PHONY: medium.en
+.PHONY: medium
+.PHONY: large
-tiny.en base.en medium.en small.en: main
+tiny.en tiny base.en base small.en small medium.en medium large: main
 	bash ./download-ggml-model.sh $@
 	@echo ""
 	@echo "==============================================="

README.md

@@ -4,7 +4,8 @@ C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-te
 - Plain C/C++ implementation without dependencies
 - ARM_NEON and AVX intrinsics support
-- F16 support
+- Mixed F16 / F32 support
+- Low memory usage (Flash Attention + Flash Forward)
 ## Usage
@@ -27,9 +28,33 @@ For a quick demo, simply run `make base.en`:
 ```bash
 $ make base.en
-Downloading base.en (142 MB just once)
-mkdir -p models
-models/ggml-base.en.bin 100%[=================================>] 141.11M 7.50MB/s in 19s
+gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c
+g++ -pthread -O3 -std=c++11 -c main.cpp
+g++ -o main ggml.o main.o
./main -h
usage: ./main [options]
options:
-h, --help show this help message and exit
-s SEED, --seed SEED RNG seed (default: -1)
-t N, --threads N number of threads to use during computation (default: 4)
-T N, --tokens N maximum number of tokens to generate per iteration (default: 64)
-v, --verbose verbose output
--translate translate from source language to english
-ps, --print_special print special tokens
-l LANG, --language LANG spoken language (default: en)
-m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
-f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav)
bash ./download-ggml-model.sh base.en
Downloading ggml model base.en ...
models/ggml-base.en.bin 100%[=====================================>] 141.11M 7.84MB/s in 18s
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
You can now use it like this:
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
 ===============================================
 Running base.en on all samples in ./samples ...
@@ -52,23 +77,24 @@ whisper_model_load: n_text_layer = 6
 whisper_model_load: n_mels = 80
 whisper_model_load: f16 = 1
 whisper_model_load: type = 2
-whisper_model_load: mem_required = 782.00 MB
+whisper_model_load: mem_required = 611.00 MB
 whisper_model_load: adding 1607 extra tokens
-whisper_model_load: ggml ctx size = 186.26 MB
+whisper_model_load: ggml ctx size = 163.43 MB
-whisper_model_load: memory size = 45.66 MB
+whisper_model_load: memory size = 22.83 MB
 whisper_model_load: model size = 140.54 MB
 log_mel_spectrogram: n_sample = 176000, n_len = 1100
 log_mel_spectrogram: recording length: 11.000000 s
+main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe ...
 And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country.
-main: load time = 60.62 ms
+main: load time = 71.89 ms
-main: mel time = 38.69 ms
+main: mel time = 36.95 ms
-main: sample time = 2.36 ms
+main: sample time = 2.10 ms
-main: encode time = 875.63 ms / 145.94 ms per layer
+main: encode time = 700.94 ms / 116.82 ms per layer
-main: decode time = 103.17 ms
+main: decode time = 86.14 ms
-main: total time = 1081.13 ms
+main: total time = 898.72 ms
 ```
 The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
@@ -81,13 +107,18 @@ make samples
 This will download a few more audio files from Wikipedia and convert them to 16-bit WAV format via `ffmpeg`.
-You can download and run the other `.en` models as follows:
+You can download and run the other models as follows:
 ```
 make tiny.en
+make tiny
 make base.en
+make base
 make small.en
+make small
 make medium.en
+make medium
+make large
 ```
 For detailed usage instructions, run: `./main -h`
@@ -101,10 +132,8 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
 ## Limitations
-- Only `.en` models are supported
 - Very basic greedy sampling scheme - always pick up the top token
 - No timestamps
-- English only
 - Inference only
 - Runs on the CPU
 - Only mono-channel 16-bit WAV is supported
@@ -113,10 +142,11 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav
 | Model | Disk | Mem |
 | --- | --- | --- |
-| tiny.en | 75 MB | ~600 MB |
+| tiny | 75 MB | ~460 MB |
-| base.en | 142 MB | ~800 MB |
+| base | 142 MB | ~620 MB |
-| small.en | 466 MB | ~1.6 GB |
+| small | 466 MB | ~1.3 GB |
-| medium.en | 1.5 GB | ~3.5 GB |
+| medium | 1.5 GB | ~2.8 GB |
+| large | 2.9 GB | ~4.9 GB |
 ## ggml format

download-ggml-model.sh

@@ -6,7 +6,7 @@
 ggml_path=$(dirname $(realpath $0))
 # Whisper models
-models=( "tiny.en" "base.en" "small.en" "medium.en" )
+models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large" )
 # list available models
 function list_models {

ggml.c: 975 changed lines
File diff suppressed because it is too large.

ggml.h: 25 changed lines

@@ -12,6 +12,7 @@ extern "C" {
 #define GGML_MAX_NODES 4096
 #define GGML_MAX_PARAMS 16
 #define GGML_MAX_CONTEXTS 16
+#define GGML_MAX_OPT 4
 #ifdef __ARM_NEON
 // we use the built-in 16-bit float type
@@ -71,6 +72,9 @@ enum ggml_op {
 GGML_OP_CONV_1D_1S,
 GGML_OP_CONV_1D_2S,
+GGML_OP_FLASH_ATTN,
+GGML_OP_FLASH_FF,
 GGML_OP_COUNT,
 };
@@ -93,6 +97,7 @@ struct ggml_tensor {
 struct ggml_tensor * grad;
 struct ggml_tensor * src0;
 struct ggml_tensor * src1;
+struct ggml_tensor * opt[GGML_MAX_OPT];
 // thread scheduling
 int n_tasks;
@@ -182,14 +187,19 @@ struct ggml_tensor * ggml_new_tensor_4d(
 int ne2,
 int ne3);
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
 struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
 struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
 struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
@@ -399,6 +409,21 @@ struct ggml_tensor * ggml_conv_1d_2s(
 struct ggml_tensor * a,
 struct ggml_tensor * b);
struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
struct ggml_tensor * k,
struct ggml_tensor * v,
bool masked);
struct ggml_tensor * ggml_flash_ff(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b0,
struct ggml_tensor * b1,
struct ggml_tensor * c0,
struct ggml_tensor * c1);
 //
 // automatic differentiation
 //
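The two fused operators declared above replace the step-by-step attention and feed-forward paths that remain available under the `#else` branches in `main.cpp` below. As a rough sketch, assuming the conventional formulations used in those fallback paths (with the usual GELU activation between the two MLP matrices), they compute

$$
\mathrm{flash\_attn}(Q, K, V) \approx \operatorname{softmax}\!\Big(\tfrac{Q K^{\top}}{\sqrt{d_{\mathrm{head}}}}\Big)\, V,
\qquad
\mathrm{flash\_ff}(x) \approx \operatorname{GELU}(x W_0 + b_0)\, W_1 + b_1,
$$

where, at the encoder call site, `b0`, `b1`, `c0`, `c1` correspond to `layer.mlp_0_w`, `layer.mlp_0_b`, `layer.mlp_1_w`, `layer.mlp_1_b`, and `masked` presumably enables a causal mask (the encoder passes `false`).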

main.cpp: 334 changed lines

@@ -1,5 +1,8 @@
 #include "ggml.h"
+#define USE_FLASH_ATTN
+#define USE_FLASH_FF
 // third-party utilities
 // use your favorite implementations
 #define DR_WAV_IMPLEMENTATION
@@ -16,6 +19,7 @@
 #include <thread>
 #include <vector>
+// available whisper models
 enum e_model {
 MODEL_UNKNOWN,
 MODEL_TINY,
@@ -25,14 +29,116 @@ enum e_model {
 MODEL_LARGE,
 };
const std::map<std::string, std::pair<int, std::string>> g_lang = {
{ "en", { 0, "english", } },
{ "zh", { 1, "chinese", } },
{ "de", { 2, "german", } },
{ "es", { 3, "spanish", } },
{ "ru", { 4, "russian", } },
{ "ko", { 5, "korean", } },
{ "fr", { 6, "french", } },
{ "ja", { 7, "japanese", } },
{ "pt", { 8, "portuguese", } },
{ "tr", { 9, "turkish", } },
{ "pl", { 10, "polish", } },
{ "ca", { 11, "catalan", } },
{ "nl", { 12, "dutch", } },
{ "ar", { 13, "arabic", } },
{ "sv", { 14, "swedish", } },
{ "it", { 15, "italian", } },
{ "id", { 16, "indonesian", } },
{ "hi", { 17, "hindi", } },
{ "fi", { 18, "finnish", } },
{ "vi", { 19, "vietnamese", } },
{ "iw", { 20, "hebrew", } },
{ "uk", { 21, "ukrainian", } },
{ "el", { 22, "greek", } },
{ "ms", { 23, "malay", } },
{ "cs", { 24, "czech", } },
{ "ro", { 25, "romanian", } },
{ "da", { 26, "danish", } },
{ "hu", { 27, "hungarian", } },
{ "ta", { 28, "tamil", } },
{ "no", { 29, "norwegian", } },
{ "th", { 30, "thai", } },
{ "ur", { 31, "urdu", } },
{ "hr", { 32, "croatian", } },
{ "bg", { 33, "bulgarian", } },
{ "lt", { 34, "lithuanian", } },
{ "la", { 35, "latin", } },
{ "mi", { 36, "maori", } },
{ "ml", { 37, "malayalam", } },
{ "cy", { 38, "welsh", } },
{ "sk", { 39, "slovak", } },
{ "te", { 40, "telugu", } },
{ "fa", { 41, "persian", } },
{ "lv", { 42, "latvian", } },
{ "bn", { 43, "bengali", } },
{ "sr", { 44, "serbian", } },
{ "az", { 45, "azerbaijani", } },
{ "sl", { 46, "slovenian", } },
{ "kn", { 47, "kannada", } },
{ "et", { 48, "estonian", } },
{ "mk", { 49, "macedonian", } },
{ "br", { 50, "breton", } },
{ "eu", { 51, "basque", } },
{ "is", { 52, "icelandic", } },
{ "hy", { 53, "armenian", } },
{ "ne", { 54, "nepali", } },
{ "mn", { 55, "mongolian", } },
{ "bs", { 56, "bosnian", } },
{ "kk", { 57, "kazakh", } },
{ "sq", { 58, "albanian", } },
{ "sw", { 59, "swahili", } },
{ "gl", { 60, "galician", } },
{ "mr", { 61, "marathi", } },
{ "pa", { 62, "punjabi", } },
{ "si", { 63, "sinhala", } },
{ "km", { 64, "khmer", } },
{ "sn", { 65, "shona", } },
{ "yo", { 66, "yoruba", } },
{ "so", { 67, "somali", } },
{ "af", { 68, "afrikaans", } },
{ "oc", { 69, "occitan", } },
{ "ka", { 70, "georgian", } },
{ "be", { 71, "belarusian", } },
{ "tg", { 72, "tajik", } },
{ "sd", { 73, "sindhi", } },
{ "gu", { 74, "gujarati", } },
{ "am", { 75, "amharic", } },
{ "yi", { 76, "yiddish", } },
{ "lo", { 77, "lao", } },
{ "uz", { 78, "uzbek", } },
{ "fo", { 79, "faroese", } },
{ "ht", { 80, "haitian creole", } },
{ "ps", { 81, "pashto", } },
{ "tk", { 82, "turkmen", } },
{ "nn", { 83, "nynorsk", } },
{ "mt", { 84, "maltese", } },
{ "sa", { 85, "sanskrit", } },
{ "lb", { 86, "luxembourgish", } },
{ "my", { 87, "myanmar", } },
{ "bo", { 88, "tibetan", } },
{ "tl", { 89, "tagalog", } },
{ "mg", { 90, "malagasy", } },
{ "as", { 91, "assamese", } },
{ "tt", { 92, "tatar", } },
{ "haw", { 93, "hawaiian", } },
{ "ln", { 94, "lingala", } },
{ "ha", { 95, "hausa", } },
{ "ba", { 96, "bashkir", } },
{ "jw", { 97, "javanese", } },
{ "su", { 98, "sundanese", } },
};
 const size_t MB = 1024*1024;
 const std::map<e_model, size_t> MEM_REQ_MODEL = {
-{ MODEL_TINY, 100ull*MB },
+{ MODEL_TINY, 86ull*MB },
-{ MODEL_BASE, 190ull*MB },
+{ MODEL_BASE, 165ull*MB },
-{ MODEL_SMALL, 610ull*MB },
+{ MODEL_SMALL, 540ull*MB },
-{ MODEL_MEDIUM, 1900ull*MB },
+{ MODEL_MEDIUM, 1650ull*MB },
-{ MODEL_LARGE, 3600ull*MB },
+{ MODEL_LARGE, 3260ull*MB },
 };
 const std::map<e_model, size_t> MEM_REQ_ENCODE = {
@@ -44,11 +150,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE = {
 };
 const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
-{ MODEL_TINY, 170ull*MB },
+{ MODEL_TINY, 64ull*MB },
-{ MODEL_BASE, 230ull*MB },
+{ MODEL_BASE, 84ull*MB },
-{ MODEL_SMALL, 350ull*MB },
+{ MODEL_SMALL, 128ull*MB },
-{ MODEL_MEDIUM, 450ull*MB },
+{ MODEL_MEDIUM, 172ull*MB },
-{ MODEL_LARGE, 570ull*MB },
+{ MODEL_LARGE, 216ull*MB },
 };
 const std::map<e_model, size_t> MEM_REQ_DECODE = {
@@ -102,6 +208,10 @@ struct whisper_vocab {
 id token_solm = 50361; // ??
 id token_beg = 50363;
+// available tasks
+const id token_translate = 50358;
+const id token_transcribe = 50359;
 bool is_multilingual() const {
 return n_vocab == 51865;
 }
@@ -109,16 +219,18 @@ struct whisper_vocab {
 // command-line parameters
 struct whisper_params {
-int32_t seed = -1; // RNG seed
+int32_t seed = -1; // RNG seed, not used currently
 int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+// sampling parameter - used for the greedy strategy
 int32_t max_tokens_per_iter = 64;
 bool verbose = false;
+bool translate = false;
 bool print_special_tokens = false;
-std::string model = "models/ggml-base.en.bin"; // model path
+std::string language = "en";
+std::string model = "models/ggml-base.en.bin";
 std::string fname_inp = "samples/jfk.wav";
 };
@@ -136,6 +248,15 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
 params.max_tokens_per_iter = std::stoi(argv[++i]);
 } else if (arg == "-v" || arg == "--verbose") {
 params.verbose = true;
} else if (arg == "--translate") {
params.translate = true;
} else if (arg == "-l" || arg == "--language") {
params.language = argv[++i];
if (g_lang.find(params.language) == g_lang.end()) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
 } else if (arg == "-ps" || arg == "--print_special") {
 params.print_special_tokens = true;
 } else if (arg == "-m" || arg == "--model") {
@@ -160,16 +281,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
 fprintf(stderr, "usage: %s [options]\n", argv[0]);
 fprintf(stderr, "\n");
 fprintf(stderr, "options:\n");
 fprintf(stderr, " -h, --help show this help message and exit\n");
 fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
 fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
 fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter);
 fprintf(stderr, " -v, --verbose verbose output\n");
-fprintf(stderr, " -ps, --print_special print special tokens\n");
-fprintf(stderr, " -m FNAME, --model FNAME\n");
-fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
-fprintf(stderr, " -f FNAME, --file FNAME\n");
-fprintf(stderr, " input WAV file path (default: %s)\n", params.fname_inp.c_str());
+fprintf(stderr, " --translate translate from source language to english\n");
+fprintf(stderr, " -ps, --print_special print special tokens\n");
+fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
+fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
+fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str());
 fprintf(stderr, "\n");
 }
@@ -417,6 +538,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
 printf("%s: f16 = %d\n", __func__, hparams.f16);
 printf("%s: type = %d\n", __func__, model.type);
+// this is the total memory required to run the inference
 const size_t mem_required =
 MEM_REQ_MODEL.at(model.type) +
 MEM_REQ_ENCODE.at(model.type) +
@@ -609,11 +731,11 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
 ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
 }
-ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_k
+ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
-ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_v
+ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
-ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_k
+ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
-ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_v
+ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
 ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@@ -836,22 +958,24 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
 const int n_text_layer = hparams.n_text_layer;
 const int n_text_ctx = hparams.n_text_ctx;
+// key/value memory for the self-attention layer
 {
 const int n_mem = n_text_layer*n_text_ctx;
 const int n_elements = n_text_state*n_mem;
-model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
 }
+// key/value memory for the cross-attention layer
 {
 const int n_audio_ctx = hparams.n_audio_ctx;
 const int n_mem = n_text_layer*n_audio_ctx;
 const int n_elements = n_text_state*n_mem;
-model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
-model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
+model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
 }
 const size_t memory_size =
@@ -1057,14 +1181,14 @@ bool whisper_encode(
 Qcur),
 Qcur);
-Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+//Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-// no bias for Key
+// note: no bias for Key
 struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
 layer.attn_k_w,
 cur);
-Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+//Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
 layer.attn_v_w,
@@ -1078,6 +1202,33 @@ bool whisper_encode(
 // ------
#ifdef USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Qcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctxL,
ggml_cpy(ctxL,
Kcur,
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
0, 2, 1, 3);
struct ggml_tensor * V =
ggml_cpy(ctxL,
ggml_permute(ctxL,
ggml_reshape_3d(ctxL,
Vcur,
n_state/n_head, n_head, N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
 struct ggml_tensor * Q =
 ggml_permute(ctxL,
 ggml_cpy(ctxL,
@@ -1089,38 +1240,19 @@
 ggml_permute(ctxL,
 ggml_cpy(ctxL,
 Kcur,
-ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // F16 !
+ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
 0, 2, 1, 3);
-//// BLAS attempt
-//struct ggml_tensor * KQ =
-// ggml_mul_mat(ctxL,
-// ggml_cpy(ctxL, K, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)),
-// ggml_cpy(ctxL, Q, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)));
 // K * Q
 struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
-//struct ggml_tensor * K =
-// ggml_cpy(ctxL,
-// ggml_permute(ctxL,
-// ggml_reshape_3d(ctxL,
-// Kcur,
-// n_state/n_head, n_head, N),
-// 1, 2, 0, 3),
-// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
-// );
-//// K * Q
-//struct ggml_tensor * KQ = ggml_mul_mat(ctxL, ggml_transpose(ctxL, K), Q);
-//struct ggml_tensor * KQ_scaled =
-// ggml_scale(ctxL,
-// KQ,
-// ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-// );
-struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);
+struct ggml_tensor * KQ_scaled =
+ggml_scale(ctxL,
+KQ,
+ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+);
+struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);
 //struct ggml_tensor * V_trans =
 // ggml_permute(ctxL,
@@ -1138,10 +1270,11 @@
 Vcur,
 n_state/n_head, n_head, N),
 0, 2, 1, 3),
-ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) // F16 !
+ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
 );
 struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
+#endif
 struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
@@ -1180,6 +1313,11 @@ bool whisper_encode(
 ggml_repeat(ctxL, layer.mlp_ln_b, cur));
 }
#ifdef USE_FLASH_FF
cur = ggml_flash_ff(ctxL,
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
 // fully connected
 cur = ggml_mul_mat(ctxL,
 layer.mlp_0_w,
@@ -1200,6 +1338,7 @@
 cur = ggml_add(ctxL,
 ggml_repeat(ctxL, layer.mlp_1_b, cur),
 cur);
+#endif
 }
 // output from this layer
@@ -1368,7 +1507,7 @@ bool whisper_decode(
 ((int32_t *) position->data)[i] = n_past + i;
 }
-// wte + wpe
+// token encoding + position encoding
 struct ggml_tensor * cur =
 ggml_add(ctx0,
 ggml_get_rows(ctx0, model.d_te, embd),
@@ -1420,7 +1559,7 @@
 Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
-// no bias for Key
+// note: no bias for Key
 struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
 layer.attn_k_w,
 cur);
@@ -1506,7 +1645,7 @@
 // norm
 {
-cur = ggml_norm(ctxL, inpCA); // Note we use inpCA here
+cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here
 // cur = ln_0_w*cur + ln_0_b
 cur = ggml_add(ctxL,
@@ -1589,7 +1728,6 @@
 cur);
 }
-// add the input
 cur = ggml_add(ctxL, cur, inpCA);
@@ -1601,8 +1739,7 @@
 {
 cur = ggml_norm(ctxL, inpFF);
-// cur = ln_2_g*cur + ln_2_b
-// [ 768, N]
+// cur = mlp_ln_w*cur + mlp_ln_b
 cur = ggml_add(ctxL,
 ggml_mul(ctxL,
 ggml_repeat(ctxL, layer.mlp_ln_w, cur),
@@ -1689,11 +1826,11 @@ bool whisper_decode(
 probs_out.resize(N*n_vocab);
 memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);
-//if (N > 1) {
+if (N > 1) {
-// const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
+//const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
-// printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
+//printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
-// printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
+//printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
-//}
+}
 ggml_free(ctx0);
@@ -1981,8 +2118,36 @@ int main(int argc, char ** argv) {
 t_mel_us = ggml_time_us() - t_start_us;
 }
// print some info about the processing
{
printf("\n");
if (!vocab.is_multilingual()) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n",
__func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads,
g_lang.at(params.language).second.c_str(),
params.translate ? "translate" : "transcribe");
}
// the accumulated text context so far
 std::vector<whisper_vocab::id> prompt_past = { };
// these tokens determine the task that will be performed
std::vector<whisper_vocab::id> prompt_init = { vocab.token_sot };
if (vocab.is_multilingual()) {
prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first);
if (params.translate) {
prompt_init.push_back(vocab.token_translate);
} else {
prompt_init.push_back(vocab.token_transcribe);
}
}
 // main loop
 int seek = 0;
 while (true) {
@@ -2006,24 +2171,23 @@
 std::vector<float> probs;
 std::vector<float> logits;
-// SOT
-// ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526
-// TODO: use different initial tokens for different tasks
-std::vector<whisper_vocab::id> prompt = { vocab.token_sot };
+std::vector<whisper_vocab::id> prompt;
 int n_past = 0;
+// if we have already generated some text, use it as a prompt to condition the next generation
 if (prompt_past.size() > 0) {
 int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size()));
 prompt = { vocab.token_prev };
-prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end());
+prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
-prompt.push_back(vocab.token_sot);
 prompt_past.clear();
-prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - 1);
+prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
 }
+prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 bool done = false;
 int seek_delta = 100*CHUNK_SIZE;
 whisper_vocab::id last_id = 0;
@@ -2049,6 +2213,16 @@
 n_past += prompt.size();
 prompt.clear();
// very basic greedy sampling strategy:
//
// - always take the most probable token
// - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token
// and advance the sliding window by that amount
// - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too
//
// more sophisticated sampling strategies could be implemented here, but we keep it simple
// feel free to experiment!
//
 {
 // sample next token
 const float temp = 1.0; // TODO