diff --git a/Makefile b/Makefile index 773bde0e..1aed7bf5 100644 --- a/Makefile +++ b/Makefile @@ -30,11 +30,16 @@ samples: # runs it on all samples in the folder "./samples": .PHONY: tiny.en +.PHONY: tiny .PHONY: base.en -.PHONY: medium.en +.PHONY: base .PHONY: small.en +.PHONY: small +.PHONY: medium.en +.PHONY: medium +.PHONY: large -tiny.en base.en medium.en small.en: main +tiny.en tiny base.en base small.en small medium.en medium large: main bash ./download-ggml-model.sh $@ @echo "" @echo "===============================================" diff --git a/README.md b/README.md index 891a94a1..f4877cf2 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ C/C++ port of [OpenAI's Whisper](https://github.com/openai/whisper) speech-to-te - Plain C/C++ implementation without dependencies - ARM_NEON and AVX intrinsics support -- F16 support +- Mixed F16 / F32 support +- Low memory usage (Flash Attention + Flash Forward) ## Usage @@ -27,9 +28,33 @@ For a quick demo, simply run `make base.en`: ```bash $ make base.en -Downloading base.en (142 MB just once) -mkdir -p models -models/ggml-base.en.bin 100%[=================================>] 141.11M 7.50MB/s in 19s +gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c +g++ -pthread -O3 -std=c++11 -c main.cpp +g++ -o main ggml.o main.o +./main -h + +usage: ./main [options] + +options: + -h, --help show this help message and exit + -s SEED, --seed SEED RNG seed (default: -1) + -t N, --threads N number of threads to use during computation (default: 4) + -T N, --tokens N maximum number of tokens to generate per iteration (default: 64) + -v, --verbose verbose output + --translate translate from source language to english + -ps, --print_special print special tokens + -l LANG, --language LANG spoken language (default: en) + -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin) + -f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav) + +bash ./download-ggml-model.sh base.en +Downloading ggml model base.en ... +models/ggml-base.en.bin 100%[=====================================>] 141.11M 7.84MB/s in 18s +Done! Model 'base.en' saved in 'models/ggml-base.en.bin' +You can now use it like this: + + $ ./main -m models/ggml-base.en.bin -f samples/jfk.wav + =============================================== Running base.en on all samples in ./samples ... @@ -52,23 +77,24 @@ whisper_model_load: n_text_layer = 6 whisper_model_load: n_mels = 80 whisper_model_load: f16 = 1 whisper_model_load: type = 2 -whisper_model_load: mem_required = 782.00 MB +whisper_model_load: mem_required = 611.00 MB whisper_model_load: adding 1607 extra tokens -whisper_model_load: ggml ctx size = 186.26 MB -whisper_model_load: memory size = 45.66 MB +whisper_model_load: ggml ctx size = 163.43 MB +whisper_model_load: memory size = 22.83 MB whisper_model_load: model size = 140.54 MB log_mel_spectrogram: n_sample = 176000, n_len = 1100 log_mel_spectrogram: recording length: 11.000000 s +main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe ... + And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country. 
-main: load time = 60.62 ms -main: mel time = 38.69 ms -main: sample time = 2.36 ms -main: encode time = 875.63 ms / 145.94 ms per layer -main: decode time = 103.17 ms -main: total time = 1081.13 ms - +main: load time = 71.89 ms +main: mel time = 36.95 ms +main: sample time = 2.10 ms +main: encode time = 700.94 ms / 116.82 ms per layer +main: decode time = 86.14 ms +main: total time = 898.72 ms ``` The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`. @@ -81,13 +107,18 @@ make samples This will download a few more audio files from Wikipedia and convert them to 16-bit WAV format via `ffmpeg`. -You can download and run the other `.en` models as follows: +You can download and run the other models as follows: ``` make tiny.en +make tiny make base.en +make base make small.en +make small make medium.en +make medium +make large ``` For detailed usage instructions, run: `./main -h` @@ -101,10 +132,8 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav ## Limitations -- Only `.en` models are supported - Very basic greedy sampling scheme - always pick up the top token - No timestamps -- English only - Inference only - Runs on the CPU - Only mono-channel 16-bit WAV is supported @@ -113,10 +142,11 @@ ffmpeg -i input.mp3 -ar 16000 -ac 1 -c:a pcm_s16le output.wav | Model | Disk | Mem | | --- | --- | --- | -| tiny.en | 75 MB | ~600 MB | -| base.en | 142 MB | ~800 MB | -| small.en | 466 MB | ~1.6 GB | -| medium.en | 1.5 GB | ~3.5 GB | +| tiny | 75 MB | ~460 MB | +| base | 142 MB | ~620 MB | +| small | 466 MB | ~1.3 GB | +| medium | 1.5 GB | ~2.8 GB | +| large | 2.9 GB | ~4.9 GB | ## ggml format diff --git a/download-ggml-model.sh b/download-ggml-model.sh index 3d5fa50b..d3009d27 100755 --- a/download-ggml-model.sh +++ b/download-ggml-model.sh @@ -6,7 +6,7 @@ ggml_path=$(dirname $(realpath $0)) # Whisper models -models=( "tiny.en" "base.en" "small.en" "medium.en" ) +models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large" ) # list available models function list_models { diff --git a/ggml.c b/ggml.c index c29422ce..9b18d819 100644 --- a/ggml.c +++ b/ggml.c @@ -20,7 +20,13 @@ #define UNUSED(x) (void)(x) #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) -#define GGML_ASSERT(x) assert(x) +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) #ifdef GGML_USE_ACCELERATE #include @@ -118,6 +124,16 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) { } #endif +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t table_exp_f16[1 << 16]; + // // timing // @@ -331,7 +347,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t // leftovers for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); // should not end up here sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } #elif defined(__AVX2__) @@ -375,7 +390,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t // leftovers for (int i = n32; i < n; ++i) { - GGML_ASSERT(false); + //GGML_ASSERT(false); sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } #else @@ -558,12 +573,20 @@ inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { const ggml_float GELU_COEF_A = 0.044715; const ggml_float 
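// note: GELU_COEF_A and SQRT_2_OVER_PI implement the usual tanh approximation of GELU,
//   gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
// which ggml_gelu_f32 evaluates in the equivalent factored form
//   0.5*x*(1 + tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))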
SQRT_2_OVER_PI = 0.79788456080286535587989211986876; -inline static void ggml_vec_gelu_f32 (const int n, float * y, const float * x) { +inline static float ggml_gelu_f32(float x) { + return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) { - //y[i] = 0.5f*x[i]*(1.f + tanhf(SQRT_2_OVER_PI*(x[i] + 0.044715f*x[i]*x[i]*x[i]))); - //0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) - const ggml_float xx = x[i]; - y[i] = 0.5*xx*(1.0 + tanh(SQRT_2_OVER_PI*xx*(1.0 + GELU_COEF_A*xx*xx))); + y[i] = ggml_gelu_f32(x[i]); + } +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; } } @@ -641,6 +664,9 @@ const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "ROPE", "CONV_1D_1S", "CONV_1D_2S", + + "FLASH_ATTN", + "FLASH_FF", }; const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { @@ -678,6 +704,9 @@ const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope(x)", "conv_1d_1s(x)", "conv_1d_2s(x)", + + "flash_attn(x)", + "flash_ff(x)", }; // @@ -878,6 +907,24 @@ int ggml_up64(int n) { //////////////////////////////////////////////////////////////////////////////// struct ggml_context * ggml_init(struct ggml_init_params params) { + static bool is_first_call = true; + if (is_first_call) { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ii = (uint16_t) i; + const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii)); + table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f)); + table_exp_f16[i] = ggml_fp32_to_fp16(exp(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + + is_first_call = false; + } + // find non-used context in g_state struct ggml_context * ctx = NULL; @@ -900,7 +947,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { } if (ctx == NULL) { - GGML_PRINT_DEBUG("%s\n", "ggml_init: no unused context found"); + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); return NULL; } @@ -923,8 +970,8 @@ void ggml_free(struct ggml_context * ctx) { if (&g_state.contexts[i].context == ctx) { g_state.contexts[i].used = false; - GGML_PRINT_DEBUG("ggml_free: context %d with %d objects has been freed. memory used = %zu\n", - i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size); + GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. 
memory used = %zu\n", + __func__, i, ctx->n_objects, ctx->objects_end->offset + ctx->objects_end->size); if (ctx->mem_buffer_owned) { free(ctx->mem_buffer); @@ -1010,6 +1057,7 @@ struct ggml_tensor * ggml_new_tensor_impl( /*.grad =*/ NULL, /*.src0 =*/ NULL, /*.src1 =*/ NULL, + /*.opt =*/ { NULL }, /*.n_tasks =*/ 0, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, @@ -1079,6 +1127,14 @@ struct ggml_tensor * ggml_new_tensor_4d( return ggml_new_tensor(ctx, type, 4, ne); } +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ggml_set_i32(result, value); + + return result; +} + struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); @@ -1096,6 +1152,58 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { return tensor; } +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } + + return tensor; +} + struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { const int n = ggml_nrows(tensor); const int nc = tensor->ne[0]; @@ -1148,40 +1256,109 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { return tensor; } -float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { case GGML_TYPE_I8: { - assert(tensor->nb[0] == sizeof(int8_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; } break; case GGML_TYPE_I16: { - assert(tensor->nb[0] == sizeof(int16_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; } break; case GGML_TYPE_I32: { - assert(tensor->nb[0] == sizeof(int32_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; } break; case GGML_TYPE_F16: { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]); } break; case GGML_TYPE_F32: { - assert(tensor->nb[0] == sizeof(float)); + GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, 
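// note: ggml_set_i32 / ggml_get_i32_1d / ggml_set_i32_1d mirror the existing f32
// accessors; they are used to pass scalar operator parameters through the graph as
// 1-element I32 tensors (e.g. the "masked" flag that ggml_flash_attn stores in opt[1])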
int i, int32_t value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return ggml_fp16_to_fp32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); } break; } - assert(false); return 0.0f; } @@ -1189,32 +1366,32 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { switch (tensor->type) { case GGML_TYPE_I8: { - assert(tensor->nb[0] == sizeof(int8_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); ((int8_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_I16: { - assert(tensor->nb[0] == sizeof(int16_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); ((int16_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_I32: { - assert(tensor->nb[0] == sizeof(int32_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); ((int32_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_F16: { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = ggml_fp32_to_fp16(value); } break; case GGML_TYPE_F32: { - assert(tensor->nb[0] == sizeof(float)); + GGML_ASSERT(tensor->nb[0] == sizeof(float)); ((float *)(tensor->data))[i] = value; } break; case GGML_TYPE_COUNT: { - assert(false); + GGML_ASSERT(false); } break; } } @@ -2308,6 +2485,70 @@ struct ggml_tensor * ggml_conv_1d_2s( return result; } +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + assert(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = q; + result->src1 = k; + result->opt[0] = v; + result->opt[1] = ggml_new_i32(ctx, masked ? 
1 : 0); + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + assert(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b0; + result->opt[0] = b1; + result->opt[1] = c0; + result->opt[2] = c1; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_set_param( @@ -2415,7 +2656,7 @@ void ggml_compute_forward_dup_f32( GGML_ASSERT(false); // TODO: implement } } else { - printf("%s: this is not optimal - fix me\n", __func__); + //printf("%s: this is not optimal - fix me\n", __func__); if (dst->type == GGML_TYPE_F32) { int id = 0; @@ -4185,10 +4426,17 @@ void ggml_compute_forward_soft_max_f32( } ggml_float sum = 0.0; + for (int i = 0; i < nc; i++) { - const ggml_float v = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); - sum += v; - p[i] = v; + if (p[i] == -INFINITY) { + p[i] = 0.0; + } else { + //const float val = (p[i] == -INFINITY) ? 0.0 : exp(p[i] - max); + ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max); + const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + sum += val; + p[i] = val; + } } assert(sum > 0.0f); @@ -4362,7 +4610,6 @@ void ggml_compute_forward_conv_1d_1s_f16_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - // WHISPER if (params->type == GGML_TASK_INIT) { // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); @@ -4483,7 +4730,6 @@ void ggml_compute_forward_conv_1d_1s_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); - // WHISPER if (params->type == GGML_TASK_INIT) { // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); @@ -4630,7 +4876,6 @@ void ggml_compute_forward_conv_1d_2s_f16_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - // WHISPER if (params->type == GGML_TASK_INIT) { // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); @@ -4751,7 +4996,6 @@ void ggml_compute_forward_conv_1d_2s_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); - // WHISPER if (params->type == GGML_TASK_INIT) { // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); @@ -4841,6 +5085,607 @@ void ggml_compute_forward_conv_1d_2s( } } +// ggml_compute_forward_flash_attn + +void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + 
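// ggml_compute_forward_flash_attn_f32 computes scaled dot-product attention one q row
// at a time, without materializing the full KQ^T matrix as a graph tensor: for each q
// row it accumulates the scores S = (K*q)/sqrt(D) in a per-thread scratch buffer
// (params->wdata), applies the optional causal mask, runs the softmax in place using
// the precomputed table_exp_f16 lookup instead of exp(), and finally reduces S against
// V to produce one row of dst. Work is split across threads by q rows (ith/nth).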
//const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + for (int i = 0; i < M; i++) { + max = MAX(max, S[i]); + } + + ggml_float sum = 0.0; + + for (int i = 0; i < M; i++) { + if (S[i] == -INFINITY) { + S[i] = 0.0; + } else { + //const float val = (S[i] == -INFINITY) ? 
0.0 : exp(S[i] - max); + ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); + const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + sum += val; + S[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int neq0 = q->ne[0]; + const int neq1 = q->ne[1]; + const int neq2 = q->ne[2]; + const int neq3 = q->ne[3]; + + const int nek0 = k->ne[0]; + const int nek1 = k->ne[1]; + //const int nek2 = k->ne[2]; + //const int nek3 = k->ne[3]; + + //const int nev0 = v->ne[0]; + const int nev1 = v->ne[1]; + //const int nev2 = v->ne[2]; + //const int nev3 = v->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + //const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nbk0 = k->nb[0]; + const int nbk1 = k->nb[1]; + const int nbk2 = k->nb[2]; + const int nbk3 = k->nb[3]; + + const int nbq0 = q->nb[0]; + const int nbq1 = q->nb[1]; + const int nbq2 = q->nb[2]; + const int nbq3 = q->nb[3]; + + const int nbv0 = v->nb[0]; + const int nbv1 = v->nb[1]; + const int nbv2 = v->nb[2]; + const int nbv3 = v->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = neq0; + const int N = neq1; + const int P = nek1 - N; + const int M = P + N; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0/sqrt((double) D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + for (int i = 0; i < M; i++) { + max = MAX(max, S[i]); + } + + ggml_float sum = 0.0; + + for (int i = 0; i < M; i++) { + if (S[i] == -INFINITY) { + S[i] = 0.0; + } else { + //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); + ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); + const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + sum += val; + S[i] = val; + } + } + + assert(sum > 0.0f); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = ggml_fp32_to_fp16(S[i]); + } + + for (int ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } +} + +void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int nea0 = a->ne[0]; + const int nea1 = a->ne[1]; + const int nea2 = a->ne[2]; + const int nea3 = a->ne[3]; + + const int neb00 = b0->ne[0]; + const int neb01 = b0->ne[1]; + //const int neb02 = b0->ne[2]; + //const int neb03 = b0->ne[3]; + + const int neb10 = b1->ne[0]; + const int neb11 = b1->ne[1]; + //const int neb12 = b1->ne[2]; + //const int neb13 = b1->ne[3]; + + const int nec00 = c0->ne[0]; + const int nec01 = c0->ne[1]; + //const int nec02 = c0->ne[2]; + //const int nec03 = c0->ne[3]; + + const int nec10 = c1->ne[0]; + const int nec11 = c1->ne[1]; + //const int nec12 = c1->ne[2]; + //const int nec13 = c1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + + const int nba0 = a->nb[0]; + const int nba1 = a->nb[1]; + const int nba2 = a->nb[2]; + const int nba3 = a->nb[3]; + + const int nbb00 = b0->nb[0]; + const int nbb01 = b0->nb[1]; + const int nbb02 = b0->nb[2]; + const int nbb03 = b0->nb[3]; + + const int nbb10 = b1->nb[0]; + //const int nbb11 = b1->nb[1]; + //const int nbb12 = b1->nb[2]; + //const int nbb13 = b1->nb[3]; + + const int nbc00 = c0->nb[0]; + const int nbc01 = c0->nb[1]; + const int nbc02 = c0->nb[2]; + const int nbc03 = c0->nb[3]; + + const int nbc10 = 
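// ggml_compute_forward_flash_ff_f16 fuses the transformer MLP for one row of a at a time:
//   dst_row = c0 * gelu(b0 * a_row + b1) + c1
// (b0/b1 = fc weight/bias, c0/c1 = proj weight/bias). The intermediate activation of
// size M = neb01 lives in the per-thread scratch buffer instead of a graph tensor; it is
// converted to F16 once so that the table_gelu_f16 lookup and ggml_vec_dot_f16 can be
// used for the second matrix product.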
c1->nb[0]; + //const int nbc11 = c1->nb[1]; + //const int nbc12 = c1->nb[2]; + //const int nbc13 = c1->nb[3]; + + const int nb0 = dst->nb[0]; + const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int D = nea0; + //const int N = nea1; + const int M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int i = 0; i < M; i++) { + S16[i] = ggml_fp32_to_fp16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + assert(false); + } break; + } +} + ///////////////////////////////// void ggml_compute_forward(struct 
ggml_compute_params * params, struct ggml_tensor * tensor) { @@ -4967,13 +5812,24 @@ void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tenso { ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); } break; + case GGML_OP_FLASH_ATTN: + { + int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor); + } break; case GGML_OP_NONE: { // nop } break; case GGML_OP_COUNT: { - assert(false); + GGML_ASSERT(false); } break; }; } @@ -5205,6 +6061,14 @@ void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tenso { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_FLASH_ATTN: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; case GGML_OP_NONE: { // nop @@ -5246,6 +6110,12 @@ void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) ggml_visit_parents(cgraph, node->src1); } + for (int i = 0; i < GGML_MAX_OPT; ++i) { + if (node->opt[i]) { + ggml_visit_parents(cgraph, node->opt[i]); + } + } + if (node->op == GGML_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. a constant) assert(cgraph->n_leafs < GGML_MAX_NODES); @@ -5591,7 +6461,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_CONV_1D_1S: case GGML_OP_CONV_1D_2S: { - // WHISPER node->n_tasks = n_threads; GGML_ASSERT(node->src0->ne[3] == 1); @@ -5617,6 +6486,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) GGML_ASSERT(false); } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + node->n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + } + work_size = MAX(work_size, cur); } break; case GGML_OP_NONE: diff --git a/ggml.h b/ggml.h index 1078fbe8..465a9b6d 100644 --- a/ggml.h +++ b/ggml.h @@ -12,6 +12,7 @@ extern "C" { #define GGML_MAX_NODES 4096 #define GGML_MAX_PARAMS 16 #define GGML_MAX_CONTEXTS 16 +#define GGML_MAX_OPT 4 #ifdef __ARM_NEON // we use the built-in 16-bit float type @@ -71,6 +72,9 @@ enum ggml_op { GGML_OP_CONV_1D_1S, GGML_OP_CONV_1D_2S, + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + 
GGML_OP_COUNT, }; @@ -93,6 +97,7 @@ struct ggml_tensor { struct ggml_tensor * grad; struct ggml_tensor * src0; struct ggml_tensor * src1; + struct ggml_tensor * opt[GGML_MAX_OPT]; // thread scheduling int n_tasks; @@ -182,14 +187,19 @@ struct ggml_tensor * ggml_new_tensor_4d( int ne2, int ne3); +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src); struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); @@ -399,6 +409,21 @@ struct ggml_tensor * ggml_conv_1d_2s( struct ggml_tensor * a, struct ggml_tensor * b); +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + // // automatic differentiation // diff --git a/main.cpp b/main.cpp index 40835ba7..326a8a70 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,8 @@ #include "ggml.h" +#define USE_FLASH_ATTN +#define USE_FLASH_FF + // third-party utilities // use your favorite implementations #define DR_WAV_IMPLEMENTATION @@ -16,6 +19,7 @@ #include #include +// available whisper models enum e_model { MODEL_UNKNOWN, MODEL_TINY, @@ -25,14 +29,116 @@ enum e_model { MODEL_LARGE, }; +const std::map> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", } }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "iw", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { 
"az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, +}; + const size_t MB = 1024*1024; const std::map MEM_REQ_MODEL = { - { MODEL_TINY, 100ull*MB }, - { MODEL_BASE, 190ull*MB }, - { MODEL_SMALL, 610ull*MB }, - { MODEL_MEDIUM, 1900ull*MB }, - { MODEL_LARGE, 3600ull*MB }, + { MODEL_TINY, 86ull*MB }, + { MODEL_BASE, 165ull*MB }, + { MODEL_SMALL, 540ull*MB }, + { MODEL_MEDIUM, 1650ull*MB }, + { MODEL_LARGE, 3260ull*MB }, }; const std::map MEM_REQ_ENCODE = { @@ -44,11 +150,11 @@ const std::map MEM_REQ_ENCODE = { }; const std::map MEM_REQ_ENCODE_LAYER = { - { MODEL_TINY, 170ull*MB }, - { MODEL_BASE, 230ull*MB }, - { MODEL_SMALL, 350ull*MB }, - { MODEL_MEDIUM, 450ull*MB }, - { MODEL_LARGE, 570ull*MB }, + { MODEL_TINY, 64ull*MB }, + { MODEL_BASE, 84ull*MB }, + { MODEL_SMALL, 128ull*MB }, + { MODEL_MEDIUM, 172ull*MB }, + { MODEL_LARGE, 216ull*MB }, }; const std::map MEM_REQ_DECODE = { @@ -102,6 +208,10 @@ struct whisper_vocab { id token_solm = 50361; // ?? 
id token_beg = 50363; + // available tasks + const id token_translate = 50358; + const id token_transcribe = 50359; + bool is_multilingual() const { return n_vocab == 51865; } @@ -109,16 +219,18 @@ struct whisper_vocab { // command-line parameters struct whisper_params { - int32_t seed = -1; // RNG seed + int32_t seed = -1; // RNG seed, not used currently int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + // sampling parameter - used for the greedy strategy int32_t max_tokens_per_iter = 64; - bool verbose = false; + bool verbose = false; + bool translate = false; bool print_special_tokens = false; - std::string model = "models/ggml-base.en.bin"; // model path - + std::string language = "en"; + std::string model = "models/ggml-base.en.bin"; std::string fname_inp = "samples/jfk.wav"; }; @@ -136,6 +248,15 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.max_tokens_per_iter = std::stoi(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; + } else if (arg == "--translate") { + params.translate = true; + } else if (arg == "-l" || arg == "--language") { + params.language = argv[++i]; + if (g_lang.find(params.language) == g_lang.end()) { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } } else if (arg == "-ps" || arg == "--print_special") { params.print_special_tokens = true; } else if (arg == "-m" || arg == "--model") { @@ -160,16 +281,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); - fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter); - fprintf(stderr, " -v, --verbose verbose output\n"); - fprintf(stderr, " -ps, --print_special print special tokens\n"); - fprintf(stderr, " -m FNAME, --model FNAME\n"); - fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME\n"); - fprintf(stderr, " input WAV file path (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -T N, --tokens N maximum number of tokens to generate per iteration (default: %d)\n", params.max_tokens_per_iter); + fprintf(stderr, " -v, --verbose verbose output\n"); + fprintf(stderr, " --translate translate from source language to english\n"); + fprintf(stderr, " -ps, --print_special print special tokens\n"); + fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str()); fprintf(stderr, "\n"); } @@ -417,6 +538,7 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe printf("%s: f16 = %d\n", __func__, 
hparams.f16); printf("%s: type = %d\n", __func__, model.type); + // this is the total memory required to run the inference const size_t mem_required = MEM_REQ_MODEL.at(model.type) + MEM_REQ_ENCODE.at(model.type) + @@ -609,11 +731,11 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b } - ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_k - ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_v + ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k + ctx_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v - ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_k - ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // memory_cross_v + ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k + ctx_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead @@ -836,22 +958,24 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe const int n_text_layer = hparams.n_text_layer; const int n_text_ctx = hparams.n_text_ctx; + // key/value memory for the self-attention layer { const int n_mem = n_text_layer*n_text_ctx; const int n_elements = n_text_state*n_mem; - model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); + model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); } + // key/value memory for the cross-attention layer { const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; - model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements); + model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); } const size_t memory_size = @@ -1057,14 +1181,14 @@ bool whisper_encode( Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); - // no bias for Key + // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, layer.attn_k_w, cur); - Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, layer.attn_v_w, @@ -1078,6 +1202,33 @@ bool whisper_encode( // ------ +#ifdef USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Kcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctxL, + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + Vcur, + n_state/n_head, n_head, N), + 1, 2, 0, 3), + 
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); +#else struct ggml_tensor * Q = ggml_permute(ctxL, ggml_cpy(ctxL, @@ -1089,38 +1240,19 @@ bool whisper_encode( ggml_permute(ctxL, ggml_cpy(ctxL, Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), // F16 ! + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), 0, 2, 1, 3); - //// BLAS attempt - //struct ggml_tensor * KQ = - // ggml_mul_mat(ctxL, - // ggml_cpy(ctxL, K, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head)), - // ggml_cpy(ctxL, Q, ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, N, n_head))); - // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); - //struct ggml_tensor * K = - // ggml_cpy(ctxL, - // ggml_permute(ctxL, - // ggml_reshape_3d(ctxL, - // Kcur, - // n_state/n_head, n_head, N), - // 1, 2, 0, 3), - // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head) - // ); + struct ggml_tensor * KQ_scaled = + ggml_scale(ctxL, + KQ, + ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + ); - //// K * Q - //struct ggml_tensor * KQ = ggml_mul_mat(ctxL, ggml_transpose(ctxL, K), Q); - - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, - // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) - // ); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled); //struct ggml_tensor * V_trans = // ggml_permute(ctxL, @@ -1138,10 +1270,11 @@ bool whisper_encode( Vcur, n_state/n_head, n_head, N), 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) // F16 ! + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) ); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); +#endif struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); @@ -1180,6 +1313,11 @@ bool whisper_encode( ggml_repeat(ctxL, layer.mlp_ln_b, cur)); } +#ifdef USE_FLASH_FF + cur = ggml_flash_ff(ctxL, + ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else // fully connected cur = ggml_mul_mat(ctxL, layer.mlp_0_w, @@ -1200,6 +1338,7 @@ bool whisper_encode( cur = ggml_add(ctxL, ggml_repeat(ctxL, layer.mlp_1_b, cur), cur); +#endif } // output from this layer @@ -1368,7 +1507,7 @@ bool whisper_decode( ((int32_t *) position->data)[i] = n_past + i; } - // wte + wpe + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, ggml_get_rows(ctx0, model.d_te, embd), @@ -1420,7 +1559,7 @@ bool whisper_decode( Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); - // no bias for Key + // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, layer.attn_k_w, cur); @@ -1506,7 +1645,7 @@ bool whisper_decode( // norm { - cur = ggml_norm(ctxL, inpCA); // Note we use inpCA here + cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctxL, @@ -1589,7 +1728,6 @@ bool whisper_decode( cur); } - // add the input cur = ggml_add(ctxL, cur, inpCA); @@ -1601,8 +1739,7 @@ bool whisper_decode( { cur = ggml_norm(ctxL, inpFF); - // cur = ln_2_g*cur + ln_2_b - // [ 768, N] + // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctxL, ggml_mul(ctxL, ggml_repeat(ctxL, layer.mlp_ln_w, cur), @@ -1689,11 +1826,11 @@ bool whisper_decode( probs_out.resize(N*n_vocab); 
memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab); - //if (N > 1) { - // const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; - // printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); - // printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); - //} + if (N > 1) { + //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; + //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); + //printf("%s: max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx); + } ggml_free(ctx0); @@ -1981,8 +2118,36 @@ int main(int argc, char ** argv) { t_mel_us = ggml_time_us() - t_start_us; } + // print some info about the processing + { + printf("\n"); + if (!vocab.is_multilingual()) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s ...\n", + __func__, int(pcmf32.size()), float(pcmf32.size())/SAMPLE_RATE, params.n_threads, + g_lang.at(params.language).second.c_str(), + params.translate ? "translate" : "transcribe"); + } + + // the accumulated text context so far std::vector prompt_past = { }; + // these tokens determine the task that will be performed + std::vector prompt_init = { vocab.token_sot }; + if (vocab.is_multilingual()) { + prompt_init.push_back(vocab.token_sot + 1 + g_lang.at(params.language).first); + if (params.translate) { + prompt_init.push_back(vocab.token_translate); + } else { + prompt_init.push_back(vocab.token_transcribe); + } + } + // main loop int seek = 0; while (true) { @@ -2006,24 +2171,23 @@ int main(int argc, char ** argv) { std::vector probs; std::vector logits; - // SOT - // ref: https://github.com/openai/whisper/blob/15ab54826343c27cfaf44ce31e9c8fb63d0aa775/whisper/decoding.py#L506-L526 - // TODO: use different initial tokens for different tasks - std::vector prompt = { vocab.token_sot }; + std::vector prompt; int n_past = 0; + // if we have already generated some text, use it as a prompt to condition the next generation if (prompt_past.size() > 0) { int n_take = std::min(model.hparams.n_text_ctx/2, int(prompt_past.size())); prompt = { vocab.token_prev }; - prompt.insert(prompt.end(), prompt_past.end() - n_take, prompt_past.end()); - prompt.push_back(vocab.token_sot); + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); prompt_past.clear(); - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - 1); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); } + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + bool done = false; int seek_delta = 100*CHUNK_SIZE; whisper_vocab::id last_id = 0; @@ -2049,6 +2213,16 @@ int main(int argc, char ** argv) { n_past += prompt.size(); prompt.clear(); + // very basic greedy sampling strategy: + // + // - always take the most probable token + // - if we have accumulated more than 'params.max_tokens_per_iter' -> pick most probable timestamp token + // and advance the sliding window by that amount + // - in the meantime, if we encounter 2 consecutive timestamp tokens, we advance the sliding window too + // + // more sophisticated sampling strategies could be implemented here, but we keep it simple + // feel free to 
experiment! + // { // sample next token const float temp = 1.0; // TODO
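// The sampling code that follows here in main.cpp is truncated in this diff. As an
// illustrative sketch only (not the implementation from this change), the greedy
// strategy described above reduces to an argmax over the probability row of the last
// decoded token; when the sliding window has to be advanced, the argmax is restricted
// to the timestamp tokens (ids >= token_beg). A hypothetical, self-contained helper:

static int sample_greedy_sketch(const float * probs, int n_vocab, int token_beg, bool force_timestamp) {
    // when forcing a timestamp, only consider ids >= token_beg (the timestamp tokens)
    const int i0 = force_timestamp ? token_beg : 0;

    int   best_id = i0;
    float best_p  = -1.0f;
    for (int i = i0; i < n_vocab; ++i) {
        if (probs[i] > best_p) {
            best_p  = probs[i];
            best_id = i;
        }
    }
    return best_id;
}

// a caller might invoke it as, e.g.:
//   sample_greedy_sketch(probs.data() + (N - 1)*n_vocab, n_vocab, vocab.token_beg,
//                        n_generated > params.max_tokens_per_iter);
// where N, n_vocab and the hypothetical n_generated counter are assumed to follow the
// surrounding code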