diff --git a/Makefile b/Makefile index 4e1315f9..72678401 100644 --- a/Makefile +++ b/Makefile @@ -1,25 +1,80 @@ -CC_SDL=`sdl2-config --cflags --libs` +UNAME_S := $(shell uname -s) +UNAME_P := $(shell uname -p) +UNAME_M := $(shell uname -m) -main: ggml.o whisper.o main.o - g++ -pthread -o main ggml.o whisper.o main.o +# +# Compile flags +# + +CFLAGS = -O3 -std=c11 +CXXFLAGS = -O3 -std=c++11 + +CFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function +CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-unused-function + +# OS specific +# TODO: support Windows +ifeq ($(UNAME_S),Linux) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif +ifeq ($(UNAME_S),Darwin) + CFLAGS += -pthread + CXXFLAGS += -pthread +endif + +# Architecture specific +ifeq ($(UNAME_P),x86_64) + CFLAGS += -mavx -mavx2 -mfma -mf16c +endif +ifneq ($(filter arm%,$(UNAME_P)),) + # Mac M1 +endif +ifneq ($(filter aarch64%,$(UNAME_P)),) + endif + ifneq ($(filter armv6%,$(UNAME_M)),) + # Raspberry Pi 1, 2, 3 + CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access +endif +ifneq ($(filter armv7%,$(UNAME_M)),) + # Raspberry Pi 4 + CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations +endif +ifneq ($(filter armv8%,$(UNAME_M)),) + # Raspberry Pi 4 + CFLAGS += -mfp16-format=ieee -mno-unaligned-access +endif + +# +# Build library + main +# + +main: main.cpp ggml.o whisper.o + $(CXX) $(CXXFLAGS) main.cpp whisper.o ggml.o -o main ./main -h ggml.o: ggml.c ggml.h - gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c + $(CC) $(CFLAGS) -c ggml.c whisper.o: whisper.cpp whisper.h - gcc -pthread -O3 -std=c++11 -c whisper.cpp + $(CXX) $(CXXFLAGS) -c whisper.cpp -main.o: main.cpp ggml.h - g++ -pthread -O3 -std=c++11 -c main.cpp - -stream: stream.cpp - g++ -pthread -O3 -std=c++11 -o stream stream.cpp ggml.o whisper.o $(CC_SDL) - -# clean up the directory clean: rm -f *.o main +# +# Examples +# + +CC_SDL=`sdl2-config --cflags --libs` + +stream: 
stream.cpp ggml.o whisper.o + $(CXX) $(CXXFLAGS) stream.cpp ggml.o whisper.o -o stream $(CC_SDL) + +# +# Audio samples +# + # download a few audio samples into folder "./samples": .PHONY: samples samples: @@ -36,6 +91,9 @@ samples: @ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav @rm samples/mm1.wav +# +# Models +# # if not already downloaded, the following targets download the specified model and # runs it on all samples in the folder "./samples": diff --git a/README.md b/README.md index e570bcd7..a73dfb8a 100644 --- a/README.md +++ b/README.md @@ -7,13 +7,9 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - Mixed F16 / F32 precision - Low memory usage (Flash Attention + Flash Forward) - Zero memory allocations at runtime -- Runs on the CPU (Mac and Linux) +- Runs on the CPU - [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h) - -Incoming features: -- [Realtime audio input transcription](https://github.com/ggerganov/whisper.cpp/issues/10#issuecomment-1264665959) -- [Raspberry Pi support](https://github.com/ggerganov/whisper.cpp/issues/7) -- [Android support](https://github.com/ggerganov/whisper.cpp/issues/8) +- Supported platforms: Linux, Mac OS (Intel and Arm), Raspberry Pi, Android ## Usage @@ -35,13 +31,12 @@ For a quick demo, simply run `make base.en`: ```java $ make base.en - -gcc -pthread -O3 -mavx -mavx2 -mfma -mf16c -c ggml.c -g++ -pthread -O3 -std=c++11 -c main.cpp -g++ -pthread -o main ggml.o main.o +cc -O3 -std=c11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c ggml.c +c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread -c whisper.cpp +c++ -O3 -std=c++11 -Wall -Wextra -Wno-unused-parameter -Wno-unused-function -pthread main.cpp whisper.o ggml.o -o main ./main -h -usage: ./main [options] +usage: ./main [options] file0.wav file1.wav ... 
options: -h, --help show this help message and exit @@ -53,11 +48,11 @@ options: -nt, --no_timestamps do not print timestamps -l LANG, --language LANG spoken language (default: en) -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin) - -f FNAME, --file FNAME input WAV file path (default: samples/jfk.wav) + -f FNAME, --file FNAME input WAV file path bash ./download-ggml-model.sh base.en Downloading ggml model base.en ... -models/ggml-base.en.bin 100%[=====================================>] 141.11M 8.58MB/s in 22s +models/ggml-base.en.bin 100%[===================================>] 141.11M 6.49MB/s in 23s Done! Model 'base.en' saved in 'models/ggml-base.en.bin' You can now use it like this: @@ -90,20 +85,18 @@ whisper_model_load: adding 1607 extra tokens whisper_model_load: ggml ctx size = 163.43 MB whisper_model_load: memory size = 22.83 MB whisper_model_load: model size = 140.54 MB -log_mel_spectrogram: n_sample = 176000, n_len = 1100 -log_mel_spectrogram: recording length: 11.000000 s -main: processing 176000 samples (11.0 sec), 4 threads, lang = english, task = transcribe, timestamps = 1 ... +main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, lang = en, task = transcribe, timestamps = 1 ... -[00:00.000 --> 00:11.000] And so my fellow Americans ask not what your country can do for you. Ask what you can do for your country. +[00:00.000 --> 00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country. 
-main: load time = 82.05 ms -main: mel time = 44.15 ms -main: sample time = 1.98 ms -main: encode time = 674.77 ms / 112.46 ms per layer -main: decode time = 82.91 ms -main: total time = 886.29 ms +whisper_print_timings: load time = 77.48 ms +whisper_print_timings: mel time = 26.10 ms +whisper_print_timings: sample time = 2.19 ms +whisper_print_timings: encode time = 632.95 ms / 105.49 ms per layer +whisper_print_timings: decode time = 85.11 ms / 14.18 ms per layer +whisper_print_timings: total time = 824.14 ms ``` The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`. @@ -220,10 +213,16 @@ $ ./stream -m models/ggml-small.en.bin -t 8 https://user-images.githubusercontent.com/1991296/193465125-c163d304-64f6-4f5d-83e5-72239c9a203e.mp4 +## Implementation details + +- The core tensor operations are implemented in C ([ggml.h](ggml.h) / [ggml.c](ggml.c)) +- The high-level C-style API is implemented in C++ ([whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)) +- Simple usage is demonstrated in [main.cpp](main.cpp) +- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](stream.cpp) + ## Limitations -- Very basic greedy sampling scheme - always pick up the top token -- Only 16-bit WAV at 16 kHz is supported +- Very basic greedy sampling scheme - always pick up the top token. You can implement your own strategy - Inference only - No GPU support diff --git a/ggml.c b/ggml.c index ad59e077..b9ea9e0b 100644 --- a/ggml.c +++ b/ggml.c @@ -1,5 +1,10 @@ #include "ggml.h" +#if defined(_MSC_VER) +#include +#else +#include +#endif #include #include #if defined(_MSC_VER) @@ -21,7 +26,12 @@ #include #define GGML_DEBUG 0 -#define GGML_MEM_ALIGN 16 + +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -210,7 +220,7 @@ int64_t ggml_cycles_per_ms(void) { #endif #if defined(_MSC_VER) -const size_t CACHE_LINE_SIZE_F32 = 64/sizeof(float); +const size_t CACHE_LINE_SIZE_F32 = 64 / sizeof(float); #else const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE / sizeof(float); #endif @@ -338,6 +348,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t #ifdef __ARM_NEON const int n32 = (n & ~31); +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) float16x8_t sum0 = vdupq_n_f16(0); float16x8_t sum1 = vdupq_n_f16(0); float16x8_t sum2 = vdupq_n_f16(0); @@ -377,6 +388,61 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0f32), vget_high_f32(sum0f32)); sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); +#else + float32x4_t sum0 = vdupq_n_f32(0); + float32x4_t sum1 = vdupq_n_f32(0); + float32x4_t sum2 = vdupq_n_f32(0); + float32x4_t sum3 = vdupq_n_f32(0); + float32x4_t sum4 = vdupq_n_f32(0); + float32x4_t sum5 = vdupq_n_f32(0); + float32x4_t sum6 = vdupq_n_f32(0); + float32x4_t sum7 = vdupq_n_f32(0); + + float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; + float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; + + for (int i = 0; i < n32; i += 32) { + x0 = vcvt_f32_f16(vld1_f16(x + i + 0 )); + x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); + x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); + x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); + x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); + x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); + x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); + x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); + + y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); + y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); + y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); + y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); + y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); + y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); + y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); + y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); + + sum0 = vfmaq_f32(sum0, x0, y0); + sum1 = 
vfmaq_f32(sum1, x1, y1); + sum2 = vfmaq_f32(sum2, x2, y2); + sum3 = vfmaq_f32(sum3, x3, y3); + sum4 = vfmaq_f32(sum4, x4, y4); + sum5 = vfmaq_f32(sum5, x5, y5); + sum6 = vfmaq_f32(sum6, x6, y6); + sum7 = vfmaq_f32(sum7, x7, y7); + } + + // reduce sum0..sum7 to sum0 + sum0 = vaddq_f32(sum0, sum1); + sum2 = vaddq_f32(sum2, sum3); + sum4 = vaddq_f32(sum4, sum5); + sum6 = vaddq_f32(sum6, sum7); + sum0 = vaddq_f32(sum0, sum2); + sum4 = vaddq_f32(sum4, sum6); + sum0 = vaddq_f32(sum0, sum4); + + // reduce sum0 to sumf + float32x2_t sumf32 = vadd_f32(vget_low_f32(sum0), vget_high_f32(sum0)); + sumf = vget_lane_f32(sumf32, 0) + vget_lane_f32(sumf32, 1); +#endif // leftovers for (int i = n32; i < n; ++i) { @@ -519,6 +585,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ // NEON 128-bit const int n32 = (n & ~31); +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) const float16x8_t v8 = vdupq_n_f16(v); float16x8_t x0, x1, x2, x3; @@ -545,6 +612,51 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ vst1q_f16(y + i + 16, y2); vst1q_f16(y + i + 24, y3); } +#else + const float32x4_t v40 = vdupq_n_f32(v); + const float32x4_t v41 = vdupq_n_f32(v); + + float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; + float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; + + for (int i = 0; i < n32; i += 32) { + y0 = vcvt_f32_f16(vld1_f16(y + i + 0 )); + y1 = vcvt_f32_f16(vld1_f16(y + i + 4 )); + y2 = vcvt_f32_f16(vld1_f16(y + i + 8 )); + y3 = vcvt_f32_f16(vld1_f16(y + i + 12)); + y4 = vcvt_f32_f16(vld1_f16(y + i + 16)); + y5 = vcvt_f32_f16(vld1_f16(y + i + 20)); + y6 = vcvt_f32_f16(vld1_f16(y + i + 24)); + y7 = vcvt_f32_f16(vld1_f16(y + i + 28)); + + x0 = vcvt_f32_f16(vld1_f16(x + i + 0 )); + x1 = vcvt_f32_f16(vld1_f16(x + i + 4 )); + x2 = vcvt_f32_f16(vld1_f16(x + i + 8 )); + x3 = vcvt_f32_f16(vld1_f16(x + i + 12)); + x4 = vcvt_f32_f16(vld1_f16(x + i + 16)); + x5 = vcvt_f32_f16(vld1_f16(x + i + 20)); + x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); + 
x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); + + y0 = vfmaq_f32(y0, x0, v40); + y1 = vfmaq_f32(y1, x1, v40); + y2 = vfmaq_f32(y2, x2, v40); + y3 = vfmaq_f32(y3, x3, v40); + y4 = vfmaq_f32(y4, x4, v41); + y5 = vfmaq_f32(y5, x5, v41); + y6 = vfmaq_f32(y6, x6, v41); + y7 = vfmaq_f32(y7, x7, v41); + + vst1_f16(y + i + 0 , vcvt_f16_f32(y0)); + vst1_f16(y + i + 4 , vcvt_f16_f32(y1)); + vst1_f16(y + i + 8 , vcvt_f16_f32(y2)); + vst1_f16(y + i + 12, vcvt_f16_f32(y3)); + vst1_f16(y + i + 16, vcvt_f16_f32(y4)); + vst1_f16(y + i + 20, vcvt_f16_f32(y5)); + vst1_f16(y + i + 24, vcvt_f16_f32(y6)); + vst1_f16(y + i + 28, vcvt_f16_f32(y7)); + } +#endif // leftovers for (int i = n32; i < n; ++i) { @@ -944,16 +1056,18 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { if (is_first_call) { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + ggml_fp16_t ii; for (int i = 0; i < (1 << 16); ++i) { - uint16_t ii = (uint16_t) i; - const float f = ggml_fp16_to_fp32(*(ggml_fp16_t *)(&ii)); + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = ggml_fp16_to_fp32(ii); table_gelu_f16[i] = ggml_fp32_to_fp16(ggml_gelu_f32(f)); table_exp_f16[i] = ggml_fp32_to_fp16(exp(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - GGML_PRINT_DEBUG("%s: GELU table initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + GGML_PRINT_DEBUG("%s: GELU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); is_first_call = false; } @@ -4460,13 +4574,15 @@ void ggml_compute_forward_soft_max_f32( ggml_float sum = 0.0; + uint16_t ss; for (int i = 0; i < nc; i++) { if (p[i] == -INFINITY) { p[i] = 0.0; } else { //const float val = (p[i] == -INFINITY) ? 
0.0 : exp(p[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(p[i] - max); - const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + memcpy(&ss, &s, sizeof(ss)); + const float val = ggml_fp16_to_fp32(table_exp_f16[ss]); sum += val; p[i] = val; } @@ -5267,13 +5383,15 @@ void ggml_compute_forward_flash_attn_f32( ggml_float sum = 0.0; + uint16_t ss; for (int i = 0; i < M; i++) { if (S[i] == -INFINITY) { S[i] = 0.0; } else { //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); - const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + memcpy(&ss, &s, sizeof(ss)); + const float val = ggml_fp16_to_fp32(table_exp_f16[ss]); sum += val; S[i] = val; } @@ -5446,13 +5564,15 @@ void ggml_compute_forward_flash_attn_f16( ggml_float sum = 0.0; + uint16_t ss; for (int i = 0; i < M; i++) { if (S[i] == -INFINITY) { S[i] = 0.0; } else { //const float val = (S[i] == -INFINITY) ? 0.0 : exp(S[i] - max); ggml_fp16_t s = ggml_fp32_to_fp16(S[i] - max); - const float val = ggml_fp16_to_fp32(table_exp_f16[*(uint16_t *) &s]); + memcpy(&ss, &s, sizeof(ss)); + const float val = ggml_fp16_to_fp32(table_exp_f16[ss]); sum += val; S[i] = val; } diff --git a/ggml.h b/ggml.h index 465a9b6d..5b7b2582 100644 --- a/ggml.h +++ b/ggml.h @@ -108,7 +108,7 @@ struct ggml_tensor { int64_t perf_time_us; void * data; - char pad[8]; + char padding[8]; }; // computation graph diff --git a/main.cpp b/main.cpp index 1885eb6d..d363ab7e 100644 --- a/main.cpp +++ b/main.cpp @@ -36,7 +36,8 @@ struct whisper_params { std::string language = "en"; std::string model = "models/ggml-base.en.bin"; - std::string fname_inp = "samples/jfk.wav"; + + std::vector fname_inp = {}; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -45,6 +46,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; + if (arg[0] != '-') { + 
params.fname_inp.push_back(arg); + continue; + } + if (arg == "-s" || arg == "--seed") { params.seed = std::stoi(argv[++i]); } else if (arg == "-t" || arg == "--threads") { @@ -67,7 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { - params.fname_inp = argv[++i]; + params.fname_inp.push_back(argv[++i]); } else if (arg == "-h" || arg == "--help") { whisper_print_usage(argc, argv, params); exit(0); @@ -83,7 +89,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { void whisper_print_usage(int argc, char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); - fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); @@ -95,7 +101,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME input WAV file path (default: %s)\n", params.fname_inp.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME input WAV file path\n"); fprintf(stderr, "\n"); } @@ -110,106 +116,116 @@ int main(int argc, char ** argv) { params.seed = time(NULL); } + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + whisper_print_usage(argc, argv, params); + return 1; + } + // whisper init struct whisper_context * ctx = whisper_init(params.model.c_str()); - // WAV input - std::vector pcmf32; - { - drwav wav; - if (!drwav_init_file(&wav, params.fname_inp.c_str(), 
NULL)) { - fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], params.fname_inp.c_str()); - whisper_print_usage(argc, argv, {}); - return 2; - } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { + const auto fname_inp = params.fname_inp[f]; - if (wav.channels != 1 && wav.channels != 2) { - fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str()); - return 3; - } - - if (wav.sampleRate != WHISPER_SAMPLE_RATE) { - fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], params.fname_inp.c_str()); - return 4; - } - - if (wav.bitsPerSample != 16) { - fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], params.fname_inp.c_str()); - return 5; - } - - int n = wav.totalPCMFrameCount; - - std::vector pcm16; - pcm16.resize(n*wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); - drwav_uninit(&wav); - - // convert to mono, float - pcmf32.resize(n); - if (wav.channels == 1) { - for (size_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[i])/32768.0f; + // WAV input + std::vector pcmf32; + { + drwav wav; + if (!drwav_init_file(&wav, fname_inp.c_str(), NULL)) { + fprintf(stderr, "%s: failed to open WAV file '%s' - check your input\n", argv[0], fname_inp.c_str()); + whisper_print_usage(argc, argv, {}); + return 2; } - } else { - for (size_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], fname_inp.c_str()); + return 3; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str()); + return 4; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", argv[0], fname_inp.c_str()); + return 5; + } + + int n = wav.totalPCMFrameCount; + + std::vector pcm16; + pcm16.resize(n*wav.channels); + 
drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (int i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (int i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } } } - } - // print some info about the processing - { - printf("\n"); - if (!whisper_is_multilingual(ctx)) { - if (params.language != "en" || params.translate) { - params.language = "en"; - params.translate = false; - printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); - } - } - printf("%s: processing %d samples (%.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n", - __func__, int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, - params.language.c_str(), - params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); - printf("\n"); - } - - // run the inference - { - whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); - - wparams.print_realtime = true; - wparams.print_progress = false; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special_tokens = params.print_special_tokens; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - fprintf(stderr, "%s: failed to process audio\n", argv[0]); - return 6; - } - - // print result; - if (!wparams.print_realtime) { + // print some info about the processing + { printf("\n"); + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + printf("%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + printf("%s: processing '%s' (%d samples, %.1f sec), %d 
threads, lang = %s, task = %s, timestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.no_timestamps ? 0 : 1); + printf("\n"); + } - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); + // run the inference + { + whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); - if (params.no_timestamps) { - printf ("%s", text); - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + wparams.print_realtime = true; + wparams.print_progress = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special_tokens = params.print_special_tokens; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; - printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + fprintf(stderr, "%s: failed to process audio\n", argv[0]); + return 6; + } + + // print result; + if (!wparams.print_realtime) { + printf("\n"); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + + if (params.no_timestamps) { + printf ("%s", text); + fflush(stdout); + } else { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); + } } } } diff --git a/stream.cpp b/stream.cpp index d0e40c20..1f84d667 100644 --- a/stream.cpp +++ b/stream.cpp @@ -238,7 +238,7 @@ int main(int argc, char ** argv) { } // 
process 3 seconds of new audio - while ((int) SDL_GetQueuedAudioSize(g_dev_id_in) < 3*WHISPER_SAMPLE_RATE*sizeof(float)) { + while (SDL_GetQueuedAudioSize(g_dev_id_in) < 3*WHISPER_SAMPLE_RATE*sizeof(float)) { SDL_Delay(1); } const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float); @@ -265,6 +265,11 @@ int main(int argc, char ** argv) { wparams.print_progress = false; wparams.print_special_tokens = params.print_special_tokens; + wparams.print_realtime = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); diff --git a/whisper.cpp b/whisper.cpp index 59a72724..1feff5e1 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1034,8 +1034,6 @@ bool whisper_encode( const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; - const int n_vocab = hparams.n_vocab; - const int n_ctx = hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; @@ -1296,7 +1294,8 @@ bool whisper_encode( struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF); { - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, inpO); ggml_graph_compute (ctxL, &gf); @@ -1332,7 +1331,8 @@ bool whisper_encode( // run the computation { - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); ggml_graph_compute (ctx0, &gf); @@ -1356,7 +1356,8 @@ bool whisper_encode( // pre-compute cross-attention memory { - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; // TODO: hack to disconnect the encoded features from the previous graph cur->op = GGML_OP_NONE; 
@@ -1466,7 +1467,8 @@ bool whisper_decode( }; struct ggml_context * ctxL = ggml_init(paramsL); - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; // norm { @@ -1749,7 +1751,8 @@ bool whisper_decode( // run the computation { - struct ggml_cgraph gf = { .n_threads = n_threads }; + struct ggml_cgraph gf = {}; + gf.n_threads = n_threads; ggml_build_forward_expand(&gf, cur); ggml_graph_compute (ctx0, &gf); @@ -2283,7 +2286,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat #if defined(_MSC_VER) result = { #else - result = (struct whisper_full_params) { + result = (struct whisper_full_params){ #endif .strategy = WHISPER_DECODE_GREEDY, .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()), @@ -2351,7 +2354,7 @@ int whisper_full( } } - if (seek >= whisper_n_len(ctx)) { + if (seek + 100 >= whisper_n_len(ctx)) { break; } @@ -2380,7 +2383,6 @@ int whisper_full( bool done = false; int seek_delta = 100*WHISPER_CHUNK_SIZE; - whisper_token last_id = 0; // print the prompt //printf("\n\n"); @@ -2410,8 +2412,6 @@ int whisper_full( // feel free to experiment! 
// { - const int n_vocab = whisper_n_vocab(ctx); - whisper_token id = 0; whisper_token tid = whisper_token_beg(ctx); @@ -2425,7 +2425,6 @@ int whisper_full( seek_delta = 2*(id - whisper_token_beg(ctx)); result_len = i + 1; } - last_id = id; // add it to the context prompt.push_back(id); @@ -2459,7 +2458,7 @@ int whisper_full( std::string text = ""; - for (int i = 0; i < result_cur.size(); i++) { + for (int i = 0; i < (int) result_cur.size(); i++) { if (params.print_special_tokens == false && result_cur[i].id >= whisper_token_eot(ctx)) { } else { text += whisper_token_to_str(ctx, result_cur[i].id); @@ -2479,7 +2478,7 @@ int whisper_full( result_all.push_back({ t0, t1, text }); } text = ""; - while (result_cur[i].id > whisper_token_beg(ctx) && i < result_cur.size()) { + while (i < (int) result_cur.size() && result_cur[i].id > whisper_token_beg(ctx)) { /* bounds check first so result_cur[i] is never read past the end */ i++; } i--;