mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-25 01:19:10 +00:00
Compare commits
26 Commits
Author | SHA1 | Date | |
---|---|---|---|
1d716d6e34 | |||
419b8a6402 | |||
1eb81f863f | |||
fba10a4c68 | |||
afe2db0fe2 | |||
a7047b2a28 | |||
32fbc8cd04 | |||
b8065d90f5 | |||
4312995974 | |||
5eeeb3412d | |||
6a69e3ae27 | |||
bf69b669a0 | |||
ea19ed33f1 | |||
675e787171 | |||
c6c3ad5a98 | |||
6a7c82501e | |||
a82d331034 | |||
c37c2443c1 | |||
0f11759406 | |||
5a5c5ddcca | |||
34e0b4b9ef | |||
b0f8013eb9 | |||
124c718c73 | |||
f66ac6dc4f | |||
9955fa4ed7 | |||
a613f16aec |
56
.github/workflows/build.yml
vendored
56
.github/workflows/build.yml
vendored
@ -119,7 +119,59 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [RelWithDebInfo]
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v1
|
||||
|
||||
- name: Add msbuild to PATH
|
||||
uses: microsoft/setup-msbuild@v1
|
||||
|
||||
- name: Fetch SDL2 and set SDL2_DIR
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: |
|
||||
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
|
||||
7z x sdl2.zip
|
||||
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Configure
|
||||
run: >
|
||||
cmake -S . -B ./build -A ${{ matrix.arch }}
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
-DWHISPER_SUPPORT_SDL2=${{ matrix.sdl2 }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Upload binaries
|
||||
if: matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
||||
windows-blas:
|
||||
runs-on: windows-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [Release]
|
||||
arch: [Win32, x64]
|
||||
blas: [ON]
|
||||
sdl2: [ON]
|
||||
@ -181,5 +233,5 @@ jobs:
|
||||
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
|
||||
uses: actions/upload-artifact@v1
|
||||
with:
|
||||
name: whisper-bin-${{ matrix.arch }}
|
||||
name: whisper-blas-bin-${{ matrix.arch }}
|
||||
path: build/bin/${{ matrix.build }}
|
||||
|
@ -1,5 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
project(whisper.cpp VERSION 1.0.3)
|
||||
|
||||
project(whisper.cpp VERSION 1.0.4)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS "on")
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
@ -4,6 +4,8 @@
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
[Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
@ -19,11 +21,11 @@ Supported platforms:
|
||||
|
||||
- [x] Mac OS (Intel and Arm)
|
||||
- [x] [iOS](examples/whisper.objc)
|
||||
- [x] Linux
|
||||
- [x] [Android](examples/whisper.android)
|
||||
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
|
||||
- [x] [WebAssembly](examples/whisper.wasm)
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
- [x] [Android](https://github.com/ggerganov/whisper.cpp/issues/30)
|
||||
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
@ -465,6 +467,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
|
||||
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
|
||||
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
|
||||
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
|
||||
| [whisper.nvim](examples/whisper.nvim) | | Speech-to-text plugin for Neovim |
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
|
Submodule bindings/ios updated: dd58b25d84...1502317fe0
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.0.3",
|
||||
"version": "1.0.4",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
@ -8,7 +8,13 @@ More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/
|
||||
./command -m ./models/ggml-small.en.bin -t 8
|
||||
|
||||
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 4 -c 0
|
||||
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
|
||||
|
||||
# Run in guided mode, the list of allowed commands is in commands.txt
|
||||
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
|
||||
|
||||
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
|
||||
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
|
||||
```
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
@ -41,6 +42,7 @@ struct whisper_params {
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string fname_out = "";
|
||||
std::string commands = "";
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -68,6 +70,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -83,22 +86,23 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
|
||||
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -484,6 +488,28 @@ float similarity(const std::string & s0, const std::string & s1) {
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
||||
std::vector<std::string> allowed_commands;
|
||||
|
||||
std::ifstream ifs(fname);
|
||||
if (!ifs.is_open()) {
|
||||
return allowed_commands;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
line = trim(line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::transform(line.begin(), line.end(),line.begin(), ::tolower);
|
||||
allowed_commands.push_back(std::move(line));
|
||||
}
|
||||
|
||||
return allowed_commands;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -521,7 +547,6 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// init audio
|
||||
|
||||
audio_async audio(30*1000);
|
||||
@ -532,6 +557,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.resume();
|
||||
|
||||
// wait for 1 second to avoid any buffered noise
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
audio.clear();
|
||||
|
||||
int max_len = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
@ -542,7 +573,94 @@ int main(int argc, char ** argv) {
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
std::vector<std::string> allowed_commands;
|
||||
std::vector<std::vector<whisper_token>> allowed_tokens;
|
||||
|
||||
std::string k_prompt = "";
|
||||
std::vector<whisper_token> k_tokens;
|
||||
|
||||
if (params.commands != "") {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: guided mode\n", __func__);
|
||||
|
||||
allowed_commands = read_allowed_commands(params.commands);
|
||||
|
||||
if (allowed_commands.empty()) {
|
||||
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
||||
return 2;
|
||||
}
|
||||
|
||||
for (const auto & cmd : allowed_commands) {
|
||||
whisper_token tokens[1024];
|
||||
allowed_tokens.emplace_back();
|
||||
|
||||
for (int l = 0; l < cmd.size(); ++l) {
|
||||
// NOTE: very important to add the whitespace !
|
||||
// the reason is that the first decoded token starts with a whitespace too!
|
||||
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
||||
|
||||
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
||||
return 3;
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
allowed_tokens.back().push_back(tokens[0]);
|
||||
}
|
||||
}
|
||||
|
||||
max_len = std::max(max_len, (int) cmd.size());
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
||||
for (const auto & token : allowed_tokens[i]) {
|
||||
fprintf(stderr, " %5d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
}
|
||||
k_prompt += allowed_commands[i];
|
||||
}
|
||||
k_prompt += ". selected word: ";
|
||||
|
||||
// tokenize prompt
|
||||
{
|
||||
k_tokens.resize(1024);
|
||||
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
||||
if (n < 0) {
|
||||
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
||||
return 4;
|
||||
}
|
||||
k_tokens.resize(n);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
||||
fprintf(stderr, "%s: tokens: [", __func__);
|
||||
for (const auto & token : k_tokens) {
|
||||
fprintf(stderr, " %d", token);
|
||||
}
|
||||
fprintf(stderr, " ]\n");
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
} else {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
|
||||
k_prompt = "Ok Whisper, start listening for commands.";
|
||||
}
|
||||
|
||||
// main loop
|
||||
while (is_running) {
|
||||
@ -568,78 +686,172 @@ int main(int argc, char ** argv) {
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
if (allowed_commands.empty()) {
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
int64_t t_ms = 0;
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
{
|
||||
int64_t t_ms = 0;
|
||||
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// command-list mode
|
||||
// guide the transcription to match the most likely command from a provided list
|
||||
|
||||
{
|
||||
audio.get(2000, pcmf32_cur);
|
||||
|
||||
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
|
||||
if (!have_prompt) {
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = 1;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
wparams.prompt_tokens = k_tokens.data();
|
||||
wparams.prompt_n_tokens = k_tokens.size();
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
// run the transformer and a single decoding pass
|
||||
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
||||
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
const auto * probs = whisper_get_probs(ctx);
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
|
||||
double psum = 0.0;
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
probs_id.push_back(std::make_pair(probs[allowed_tokens[i][0]], i));
|
||||
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
||||
probs_id.back().first += probs[allowed_tokens[i][j]];
|
||||
}
|
||||
} else {
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
probs_id.back().first /= allowed_tokens[i].size();
|
||||
psum += probs_id.back().first;
|
||||
}
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
// normalize
|
||||
for (auto & p : probs_id) {
|
||||
p.first /= psum;
|
||||
}
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
// print the commands and the respective probabilities
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
for (const auto & cmd : probs_id) {
|
||||
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
||||
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
||||
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
|
||||
}
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
// best command
|
||||
{
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
||||
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
||||
(int) std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::high_resolution_clock::now() - t_start).count());
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
audio.clear();
|
||||
}
|
||||
}
|
||||
|
9
examples/command/commands.txt
Normal file
9
examples/command/commands.txt
Normal file
@ -0,0 +1,9 @@
|
||||
enable
|
||||
disable
|
||||
cat
|
||||
dog
|
||||
apple
|
||||
red
|
||||
blue
|
||||
green
|
||||
lightblue
|
@ -62,19 +62,21 @@ struct whisper_params {
|
||||
|
||||
float word_thold = 0.01f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool no_timestamps = false;
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool diarize = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
};
|
||||
@ -94,27 +96,29 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -130,28 +134,30 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
|
||||
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -447,7 +453,7 @@ int main(int argc, char ** argv) {
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@ -462,6 +468,22 @@ int main(int argc, char ** argv) {
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initial prompt
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
|
||||
if (params.prompt.size() > 0) {
|
||||
prompt_tokens.resize(1024);
|
||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
||||
fprintf(stderr, "initial tokens: [ ");
|
||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
||||
}
|
||||
fprintf(stderr, "]\n");
|
||||
}
|
||||
|
||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||
const auto fname_inp = params.fname_inp[f];
|
||||
|
||||
@ -577,13 +599,12 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
@ -599,6 +620,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data();
|
||||
wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size();
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||
|
||||
// this callback is called on each new segment
|
||||
|
@ -10,6 +10,23 @@ More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/i
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
|
||||
|
||||
## Sliding window mode with VAD
|
||||
|
||||
Setting the `--step` argument to `0` enables the sliding window mode:
|
||||
|
||||
```java
|
||||
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
|
||||
```
|
||||
|
||||
In this mode, the tool will transcribe only after some speech activity is detected. A very
|
||||
basic VAD detector is used, but in theory a more sophisticated approach can be added. The
|
||||
`-vth` argument determines the VAD threshold - higher values will make it detect silence more often.
|
||||
It's best to tune it to the specific use case, but a value around `0.6` should be OK in general.
|
||||
When silence is detected, it will transcribe the last `--length` milliseconds of audio and output
|
||||
a transcription block that is suitable for parsing.
|
||||
|
||||
## Building
|
||||
|
||||
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
|
||||
|
||||
```bash
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Real-time speech recognition of input from a microphone
|
||||
//
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "whisper.h"
|
||||
|
||||
@ -13,6 +14,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
@ -33,15 +35,19 @@ struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t step_ms = 3000;
|
||||
int32_t length_ms = 10000;
|
||||
int32_t keep_ms = 200;
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool no_context = true;
|
||||
bool print_special = false;
|
||||
bool no_timestamps = true;
|
||||
bool no_context = true;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
@ -61,13 +67,16 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
@ -90,13 +99,16 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
|
||||
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
|
||||
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
@ -107,19 +119,56 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
// SDL Audio capture
|
||||
//
|
||||
|
||||
SDL_AudioDeviceID g_dev_id_in = 0;
|
||||
class audio_async {
|
||||
public:
|
||||
audio_async(int len_ms);
|
||||
~audio_async();
|
||||
|
||||
bool audio_sdl_init(const int capture_id) {
|
||||
if (g_dev_id_in) {
|
||||
fprintf(stderr, "%s: already initialized\n", __func__);
|
||||
return false;
|
||||
bool init(int capture_id, int sample_rate);
|
||||
|
||||
// start capturing audio via the provided SDL callback
|
||||
// keep last len_ms seconds of audio in a circular buffer
|
||||
bool resume();
|
||||
bool pause();
|
||||
bool clear();
|
||||
|
||||
// callback to be called by SDL
|
||||
void callback(uint8_t * stream, int len);
|
||||
|
||||
// get audio data from the circular buffer
|
||||
void get(int ms, std::vector<float> & audio);
|
||||
|
||||
private:
|
||||
SDL_AudioDeviceID m_dev_id_in = 0;
|
||||
|
||||
int m_len_ms = 0;
|
||||
int m_sample_rate = 0;
|
||||
|
||||
bool m_running = false;
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
||||
audio_async::audio_async(int len_ms) {
|
||||
m_len_ms = len_ms;
|
||||
}
|
||||
|
||||
audio_async::~audio_async() {
|
||||
if (m_dev_id_in) {
|
||||
SDL_CloseAudioDevice(m_dev_id_in);
|
||||
}
|
||||
}
|
||||
|
||||
bool audio_async::init(int capture_id, int sample_rate) {
|
||||
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
|
||||
|
||||
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
|
||||
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
|
||||
return (1);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
|
||||
@ -138,34 +187,232 @@ bool audio_sdl_init(const int capture_id) {
|
||||
SDL_zero(capture_spec_requested);
|
||||
SDL_zero(capture_spec_obtained);
|
||||
|
||||
capture_spec_requested.freq = WHISPER_SAMPLE_RATE;
|
||||
capture_spec_requested.freq = sample_rate;
|
||||
capture_spec_requested.format = AUDIO_F32;
|
||||
capture_spec_requested.channels = 1;
|
||||
capture_spec_requested.samples = 1024;
|
||||
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
|
||||
audio_async * audio = (audio_async *) userdata;
|
||||
audio->callback(stream, len);
|
||||
};
|
||||
capture_spec_requested.userdata = this;
|
||||
|
||||
if (capture_id >= 0) {
|
||||
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
|
||||
g_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
|
||||
g_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
|
||||
}
|
||||
if (!g_dev_id_in) {
|
||||
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
|
||||
g_dev_id_in = 0;
|
||||
m_dev_id_in = 0;
|
||||
|
||||
return false;
|
||||
} else {
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, g_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format, capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels, capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
|
||||
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
|
||||
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
|
||||
capture_spec_requested.format);
|
||||
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
|
||||
capture_spec_requested.channels);
|
||||
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
|
||||
}
|
||||
|
||||
m_sample_rate = capture_spec_obtained.freq;
|
||||
|
||||
m_audio.resize((m_sample_rate*m_len_ms)/1000);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::resume() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_running) {
|
||||
fprintf(stderr, "%s: already running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 0);
|
||||
|
||||
m_running = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::pause() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: already paused!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(m_dev_id_in, 1);
|
||||
|
||||
m_running = false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool audio_async::clear() {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
m_audio_pos = 0;
|
||||
m_audio_len = 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// callback to be called by SDL
|
||||
void audio_async::callback(uint8_t * stream, int len) {
|
||||
if (!m_running) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (m_audio_pos + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
} else {
|
||||
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void audio_async::get(int ms, std::vector<float> & result) {
|
||||
if (!m_dev_id_in) {
|
||||
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!m_running) {
|
||||
fprintf(stderr, "%s: not running!\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
|
||||
if (ms <= 0) {
|
||||
ms = m_len_ms;
|
||||
}
|
||||
|
||||
size_t n_samples = (m_sample_rate * ms) / 1000;
|
||||
if (n_samples > m_audio_len) {
|
||||
n_samples = m_audio_len;
|
||||
}
|
||||
|
||||
result.resize(n_samples);
|
||||
|
||||
int s0 = m_audio_pos - n_samples;
|
||||
if (s0 < 0) {
|
||||
s0 += m_audio.size();
|
||||
}
|
||||
|
||||
if (s0 + n_samples > m_audio.size()) {
|
||||
const size_t n0 = m_audio.size() - s0;
|
||||
|
||||
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
|
||||
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
|
||||
} else {
|
||||
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////
|
||||
|
||||
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
|
||||
const float rc = 1.0f / (2.0f * M_PI * cutoff);
|
||||
const float dt = 1.0f / sample_rate;
|
||||
const float alpha = dt / (rc + dt);
|
||||
|
||||
float y = data[0];
|
||||
|
||||
for (size_t i = 1; i < data.size(); i++) {
|
||||
y = alpha * (y + data[i] - data[i - 1]);
|
||||
data[i] = y;
|
||||
}
|
||||
}
|
||||
|
||||
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||||
const int n_samples = pcmf32.size();
|
||||
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||||
|
||||
if (n_samples_last >= n_samples) {
|
||||
// not enough samples - assume no speech
|
||||
return false;
|
||||
}
|
||||
|
||||
if (freq_thold > 0.0f) {
|
||||
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||||
}
|
||||
|
||||
float energy_all = 0.0f;
|
||||
float energy_last = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < n_samples; i++) {
|
||||
energy_all += fabsf(pcmf32[i]);
|
||||
|
||||
if (i >= n_samples - n_samples_last) {
|
||||
energy_last += fabsf(pcmf32[i]);
|
||||
}
|
||||
}
|
||||
|
||||
energy_all /= n_samples;
|
||||
energy_last /= n_samples_last;
|
||||
|
||||
if (verbose) {
|
||||
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||||
}
|
||||
|
||||
if (energy_last > vad_thold*energy_all) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
|
||||
@ -173,33 +420,46 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.keep_ms = std::min(params.keep_ms, params.step_ms); // cannot be more than step_ms
|
||||
|
||||
const int n_samples_step = (params.step_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms*1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = (params.keep_ms *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = (30000 *1e-3)*WHISPER_SAMPLE_RATE;
|
||||
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1; // number of steps to print new line
|
||||
|
||||
const bool use_vad = n_samples_step <= 0; // sliding window mode uses VAD
|
||||
|
||||
params.no_timestamps = !use_vad;
|
||||
params.no_context = use_vad;
|
||||
params.max_tokens = 0;
|
||||
|
||||
// init audio
|
||||
|
||||
if (!audio_sdl_init(params.capture_id)) {
|
||||
fprintf(stderr, "%s: audio_sdl_init() failed!\n", __func__);
|
||||
audio_async audio(params.length_ms);
|
||||
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
||||
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
audio.resume();
|
||||
|
||||
// whisper init
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
||||
|
||||
const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_30s = 30*WHISPER_SAMPLE_RATE;
|
||||
const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE;
|
||||
|
||||
std::vector<float> pcmf32(n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old;
|
||||
std::vector<float> pcmf32 (n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
|
||||
std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
|
||||
|
||||
std::vector<whisper_token> prompt_tokens;
|
||||
const int n_new_line = params.length_ms / params.step_ms - 1;
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
@ -211,23 +471,28 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing %d samples (step = %.1f sec / len = %.1f sec / keep = %.1f sec), %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
__func__,
|
||||
n_samples,
|
||||
float(n_samples)/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_len)/WHISPER_SAMPLE_RATE,
|
||||
n_samples_step,
|
||||
float(n_samples_step)/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_len )/WHISPER_SAMPLE_RATE,
|
||||
float(n_samples_keep)/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
|
||||
if (!use_vad) {
|
||||
fprintf(stderr, "%s: n_new_line = %d\n", __func__, n_new_line);
|
||||
} else {
|
||||
fprintf(stderr, "%s: using VAD, will transcribe on speech activity\n", __func__);
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
SDL_PauseAudioDevice(g_dev_id_in, 0);
|
||||
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
|
||||
std::ofstream fout;
|
||||
@ -242,6 +507,9 @@ int main(int argc, char ** argv) {
|
||||
printf("[Start speaking]");
|
||||
fflush(stdout);
|
||||
|
||||
auto t_last = std::chrono::high_resolution_clock::now();
|
||||
const auto t_start = t_last;
|
||||
|
||||
// main audio loop
|
||||
while (is_running) {
|
||||
// handle Ctrl + C
|
||||
@ -268,35 +536,64 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// process new audio
|
||||
if (n_iter > 0 && SDL_GetQueuedAudioSize(g_dev_id_in) > 2*n_samples*sizeof(float)) {
|
||||
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
|
||||
SDL_ClearQueuedAudio(g_dev_id_in);
|
||||
|
||||
if (!use_vad) {
|
||||
while (true) {
|
||||
audio.get(params.step_ms, pcmf32_new);
|
||||
|
||||
if ((int) pcmf32_new.size() > 2*n_samples_step) {
|
||||
fprintf(stderr, "\n\n%s: WARNING: cannot process audio fast enough, dropping audio ...\n\n", __func__);
|
||||
audio.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if ((int) pcmf32_new.size() >= n_samples_step) {
|
||||
audio.clear();
|
||||
break;
|
||||
}
|
||||
|
||||
SDL_Delay(1);
|
||||
}
|
||||
|
||||
const int n_samples_new = pcmf32_new.size();
|
||||
|
||||
// take up to params.length_ms audio from previous iteration
|
||||
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
|
||||
|
||||
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
|
||||
|
||||
pcmf32.resize(n_samples_new + n_samples_take);
|
||||
|
||||
for (int i = 0; i < n_samples_take; i++) {
|
||||
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
|
||||
}
|
||||
|
||||
memcpy(pcmf32.data() + n_samples_take, pcmf32_new.data(), n_samples_new*sizeof(float));
|
||||
|
||||
pcmf32_old = pcmf32;
|
||||
} else {
|
||||
const auto t_now = std::chrono::high_resolution_clock::now();
|
||||
const auto t_diff = std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count();
|
||||
|
||||
if (t_diff < 2000) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
audio.get(2000, pcmf32_new);
|
||||
|
||||
if (vad_simple(pcmf32_new, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, false)) {
|
||||
audio.get(params.length_ms, pcmf32);
|
||||
} else {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
t_last = t_now;
|
||||
}
|
||||
|
||||
while (SDL_GetQueuedAudioSize(g_dev_id_in) < n_samples*sizeof(float)) {
|
||||
SDL_Delay(1);
|
||||
}
|
||||
|
||||
const int n_samples_new = SDL_GetQueuedAudioSize(g_dev_id_in)/sizeof(float);
|
||||
|
||||
// take one second from previous iteration
|
||||
//const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new));
|
||||
|
||||
// take up to params.length_ms audio from previous iteration
|
||||
const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new));
|
||||
|
||||
//printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size());
|
||||
|
||||
pcmf32.resize(n_samples_new + n_samples_take);
|
||||
|
||||
for (int i = 0; i < n_samples_take; i++) {
|
||||
pcmf32[i] = pcmf32_old[pcmf32_old.size() - n_samples_take + i];
|
||||
}
|
||||
|
||||
SDL_DequeueAudio(g_dev_id_in, pcmf32.data() + n_samples_take, n_samples_new*sizeof(float));
|
||||
|
||||
pcmf32_old = pcmf32;
|
||||
|
||||
// run the inference
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
@ -307,7 +604,7 @@ int main(int argc, char ** argv) {
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.single_segment = true;
|
||||
wparams.single_segment = !use_vad;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
@ -325,12 +622,21 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// print result;
|
||||
{
|
||||
printf("\33[2K\r");
|
||||
if (!use_vad) {
|
||||
printf("\33[2K\r");
|
||||
|
||||
// print long empty line to clear the previous line
|
||||
printf("%s", std::string(100, ' ').c_str());
|
||||
// print long empty line to clear the previous line
|
||||
printf("%s", std::string(100, ' ').c_str());
|
||||
|
||||
printf("\33[2K\r");
|
||||
printf("\33[2K\r");
|
||||
} else {
|
||||
const int64_t t1 = (t_last - t_start).count()/1000000;
|
||||
const int64_t t0 = std::max(0.0, t1 - pcmf32.size()*1000.0/WHISPER_SAMPLE_RATE);
|
||||
|
||||
printf("\n");
|
||||
printf("### Transcription %d START | t0 = %d ms | t1 = %d ms\n", n_iter, (int) t0, (int) t1);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
@ -358,11 +664,16 @@ int main(int argc, char ** argv) {
|
||||
if (params.fname_out.length() > 0) {
|
||||
fout << std::endl;
|
||||
}
|
||||
|
||||
if (use_vad){
|
||||
printf("\n");
|
||||
printf("### Transcription %d END\n", n_iter);
|
||||
}
|
||||
}
|
||||
|
||||
++n_iter;
|
||||
|
||||
if ((n_iter % n_new_line) == 0) {
|
||||
if (!use_vad && (n_iter % n_new_line) == 0) {
|
||||
printf("\n");
|
||||
|
||||
// keep part of the audio for next iteration to try to mitigate word boundary issues
|
||||
@ -384,9 +695,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (g_dev_id_in >= 0) {
|
||||
SDL_CloseAudioDevice(g_dev_id_in);
|
||||
}
|
||||
audio.pause();
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
@ -31,7 +31,7 @@ To run this, you will need a ggml GPT-2 model: [instructions](https://github.com
|
||||
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
|
||||
|
||||
```
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://ggml.ggerganov.com/ggml-model-gpt-2-117M.bin
|
||||
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/datasets/ggerganov/ggml/raw/main/ggml-model-gpt-2-117M.bin
|
||||
```
|
||||
|
||||
## TTS
|
||||
|
@ -139,7 +139,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
||||
}
|
||||
|
||||
//printf("\n");
|
||||
//for (int i = 0; i < (int)logits_id.size(); i++) {
|
||||
//for (int i = 0; i < (int) logits_id.size(); i++) {
|
||||
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), logits_id[i].first);
|
||||
//}
|
||||
//exit(0);
|
||||
@ -825,8 +825,8 @@ Me too.
|
||||
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
|
||||
|
||||
// sampling parameters
|
||||
int32_t top_k = 20;
|
||||
float top_p = 0.98f;
|
||||
int32_t top_k = 5;
|
||||
float top_p = 0.9f;
|
||||
float temp = 1.0f;
|
||||
};
|
||||
|
||||
@ -840,7 +840,7 @@ struct gpt2_context * gpt2_init(const char * path_model) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, "gpt-2.bin");
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -913,10 +913,7 @@ std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens)
|
||||
result += ctx->vocab.id_to_token[embd[0]];
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256 ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "." ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "!" ||
|
||||
ctx->vocab.id_to_token[embd.back()] == "?") {
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -473,56 +473,15 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
return result;
|
||||
}
|
||||
|
||||
// compute similarity between two strings using Levenshtein distance
|
||||
float similarity(const std::string & s0, const std::string & s1) {
|
||||
const size_t len0 = s0.size() + 1;
|
||||
const size_t len1 = s1.size() + 1;
|
||||
const std::string k_prompt =
|
||||
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
|
||||
|
||||
std::vector<int> col(len1, 0);
|
||||
std::vector<int> prevCol(len1, 0);
|
||||
B: Hello {0}, how are you?
|
||||
A: I'm fine, thank you.
|
||||
{1}
|
||||
Here is how {0} (A) continues the dialogue:
|
||||
|
||||
for (size_t i = 0; i < len1; i++) {
|
||||
prevCol[i] = i;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len0; i++) {
|
||||
col[0] = i;
|
||||
for (size_t j = 1; j < len1; j++) {
|
||||
col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + (s0[i - 1] == s1[j - 1] ? 0 : 1));
|
||||
}
|
||||
col.swap(prevCol);
|
||||
}
|
||||
|
||||
const float dist = prevCol[len1 - 1];
|
||||
|
||||
return 1.0f - (dist / std::max(s0.size(), s1.size()));
|
||||
}
|
||||
|
||||
// generated with ChatGPT
|
||||
std::map<std::string, std::string> k_prompts = {
|
||||
{ "Santa",
|
||||
R"(Kid: Hi Santa! Are you real?
|
||||
Santa: Of course I am, my dear! Ho ho ho!
|
||||
Kid: Can you please bring me a new toy for Christmas?
|
||||
Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
|
||||
Kid: I will, Santa! Thank you!
|
||||
Santa: You're welcome, little one. Merry Christmas! Ho ho ho!
|
||||
Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
|
||||
Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
|
||||
Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
|
||||
Santa: I'm sorry, but only good boys and girls get to ride in my sleigh.
|
||||
)" },
|
||||
{ "Kid",
|
||||
R"(Kid: Hi Santa! Are you real?
|
||||
Santa: Of course I am, my dear! Ho ho ho!
|
||||
Kid: Can you please bring me a new toy for Christmas?
|
||||
Santa: I'll see what I can do, but you have to make sure to be a good boy or girl and listen to your parents.
|
||||
Kid: I will, Santa! Thank you!
|
||||
Kid: Can you tell me how you deliver all the presents to all the kids in the world in one night?
|
||||
Santa: It's a secret, but I have a lot of help from my elves and my magical sleigh. And I have a special route that I follow to make sure I visit every child.
|
||||
Kid: Wow, that's amazing! Can I please have a ride in your sleigh sometime?
|
||||
)" },
|
||||
};
|
||||
A:)";
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
@ -579,7 +538,7 @@ int main(int argc, char ** argv) {
|
||||
int n_iter = 0;
|
||||
|
||||
bool is_running = true;
|
||||
bool force_speak = params.person == "Kid";
|
||||
bool force_speak = false;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
@ -587,19 +546,13 @@ int main(int argc, char ** argv) {
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
if (k_prompts.find(params.person) == k_prompts.end()) {
|
||||
fprintf(stderr, "%s: unknown person '%s'\n", __func__, params.person.c_str());
|
||||
return 1;
|
||||
}
|
||||
gpt2_set_prompt(ctx_gpt, "");
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, k_prompts.at(params.person).c_str());
|
||||
const int voice_id = rand()%6;
|
||||
|
||||
const std::string person_other = params.person == "Santa" ? "Kid" : "Santa";
|
||||
const int voice_id = params.person == "Santa" ? 5 : 2;
|
||||
|
||||
fprintf(stderr, "gpt-2: prompt_base:\n");
|
||||
fprintf(stderr, "gpt-2: prompt:\n");
|
||||
fprintf(stderr, "========================\n\n");
|
||||
fprintf(stderr, "%s\n", gpt2_get_prompt(ctx_gpt));
|
||||
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
|
||||
fprintf(stderr, "========================\n\n");
|
||||
|
||||
// main loop
|
||||
@ -636,13 +589,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
audio.get(params.voice_ms, pcmf32_cur);
|
||||
|
||||
std::string text_heard = "Hey little one, what do you want for Christmas?";
|
||||
std::string text_heard = "";
|
||||
|
||||
if (!force_speak) {
|
||||
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
// remove text between brackets using regex
|
||||
{
|
||||
std::regex re("\\[.*?\\]");
|
||||
@ -667,13 +619,15 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
|
||||
|
||||
if (text_heard.empty() || tokens.empty()) {
|
||||
if (text_heard.empty() || tokens.empty() || force_speak) {
|
||||
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
|
||||
audio.clear();
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
force_speak = false;
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
|
||||
@ -681,9 +635,11 @@ int main(int argc, char ** argv) {
|
||||
std::string text_to_speak;
|
||||
|
||||
{
|
||||
text_heard = person_other + ": " + text_heard;
|
||||
prompt_base += "B: " + text_heard + "\n";
|
||||
|
||||
text_to_speak = gpt2_gen_text(ctx_gpt, (prompt_base + text_heard + "\n").c_str(), params.max_tokens);
|
||||
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
|
||||
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
|
||||
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
|
||||
|
||||
@ -703,13 +659,20 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
prompt_base += text_heard + "\n" + text_to_speak + "\n";
|
||||
prompt_base += "A:" + text_to_speak + "\n";
|
||||
|
||||
{
|
||||
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
|
||||
|
||||
printf("===============\n");
|
||||
printf("prompt:\n");
|
||||
printf("%s\n", prompt.c_str());
|
||||
printf("===============\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("%s\n", text_to_speak.c_str());
|
||||
|
||||
//printf("========================\n");
|
||||
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
|
||||
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
|
||||
//printf("========================\n");
|
||||
|
||||
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
||||
|
15
examples/whisper.android/.gitignore
vendored
Normal file
15
examples/whisper.android/.gitignore
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
3
examples/whisper.android/.idea/.gitignore
generated
vendored
Normal file
3
examples/whisper.android/.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
1
examples/whisper.android/.idea/.name
generated
Normal file
1
examples/whisper.android/.idea/.name
generated
Normal file
@ -0,0 +1 @@
|
||||
WhisperCppDemo
|
6
examples/whisper.android/.idea/compiler.xml
generated
Normal file
6
examples/whisper.android/.idea/compiler.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CompilerConfiguration">
|
||||
<bytecodeTargetLevel target="11" />
|
||||
</component>
|
||||
</project>
|
18
examples/whisper.android/.idea/gradle.xml
generated
Normal file
18
examples/whisper.android/.idea/gradle.xml
generated
Normal file
@ -0,0 +1,18 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="distributionType" value="DEFAULT_WRAPPED" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="modules">
|
||||
<set>
|
||||
<option value="$PROJECT_DIR$" />
|
||||
<option value="$PROJECT_DIR$/app" />
|
||||
</set>
|
||||
</option>
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
10
examples/whisper.android/.idea/misc.xml
generated
Normal file
10
examples/whisper.android/.idea/misc.xml
generated
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/build/classes" />
|
||||
</component>
|
||||
<component name="ProjectType">
|
||||
<option name="id" value="Android" />
|
||||
</component>
|
||||
</project>
|
6
examples/whisper.android/.idea/vcs.xml
generated
Normal file
6
examples/whisper.android/.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
|
||||
</component>
|
||||
</project>
|
12
examples/whisper.android/README.md
Normal file
12
examples/whisper.android/README.md
Normal file
@ -0,0 +1,12 @@
|
||||
A sample Android app using [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
<img width="300" alt="image" src="https://user-images.githubusercontent.com/1991296/208154256-82d972dc-221b-48c4-bfcb-36ce68602f93.png">
|
1
examples/whisper.android/app/.gitignore
vendored
Normal file
1
examples/whisper.android/app/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|
76
examples/whisper.android/app/build.gradle
Normal file
76
examples/whisper.android/app/build.gradle
Normal file
@ -0,0 +1,76 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
id 'org.jetbrains.kotlin.android'
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'com.whispercppdemo'
|
||||
compileSdk 33
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.whispercppdemo"
|
||||
minSdk 26
|
||||
targetSdk 32
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
ndk {
|
||||
abiFilters 'arm64-v8a', 'x86_64'
|
||||
}
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
vectorDrawables {
|
||||
useSupportLibrary true
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
}
|
||||
buildFeatures {
|
||||
compose true
|
||||
}
|
||||
composeOptions {
|
||||
kotlinCompilerExtensionVersion '1.3.1'
|
||||
}
|
||||
ndkVersion "25.0.8528842"
|
||||
externalNativeBuild {
|
||||
ndkBuild {
|
||||
path 'src/main/jni/whisper/Android.mk'
|
||||
}
|
||||
}
|
||||
packagingOptions {
|
||||
resources {
|
||||
excludes += '/META-INF/{AL2.0,LGPL2.1}'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.activity:activity-compose:1.6.1'
|
||||
implementation 'androidx.compose.material:material-icons-core:1.3.1'
|
||||
implementation 'androidx.compose.material3:material3:1.0.1'
|
||||
implementation "androidx.compose.ui:ui:1.3.2"
|
||||
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
|
||||
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
|
||||
|
||||
testImplementation 'junit:junit:4.13.2'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
|
||||
}
|
21
examples/whisper.android/app/proguard-rules.pro
vendored
Normal file
21
examples/whisper.android/app/proguard-rules.pro
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -0,0 +1,24 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("com.whispercppdemo", appContext.packageName)
|
||||
}
|
||||
}
|
32
examples/whisper.android/app/src/main/AndroidManifest.xml
Normal file
32
examples/whisper.android/app/src/main/AndroidManifest.xml
Normal file
@ -0,0 +1,32 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools">
|
||||
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||
android:fullBackupContent="@xml/backup_rules"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo"
|
||||
tools:targetApi="31">
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true"
|
||||
android:theme="@style/Theme.WhisperCppDemo">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
|
||||
<meta-data
|
||||
android:name="android.app.lib_name"
|
||||
android:value="" />
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -0,0 +1,22 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import android.os.Bundle
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.viewModels
|
||||
import com.whispercppdemo.ui.main.MainScreen
|
||||
import com.whispercppdemo.ui.main.MainScreenViewModel
|
||||
import com.whispercppdemo.ui.theme.WhisperCppDemoTheme
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
private val viewModel: MainScreenViewModel by viewModels { MainScreenViewModel.factory() }
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
setContent {
|
||||
WhisperCppDemoTheme {
|
||||
MainScreen(viewModel)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,76 @@
|
||||
package com.whispercppdemo.media
|
||||
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.File
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
|
||||
fun decodeWaveFile(file: File): FloatArray {
|
||||
val baos = ByteArrayOutputStream()
|
||||
file.inputStream().use { it.copyTo(baos) }
|
||||
val buffer = ByteBuffer.wrap(baos.toByteArray())
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
buffer.position(44)
|
||||
val shortBuffer = buffer.asShortBuffer()
|
||||
val shortArray = ShortArray(shortBuffer.limit())
|
||||
shortBuffer.get(shortArray)
|
||||
return FloatArray(shortArray.size) { index ->
|
||||
(shortArray[index] / 32767.0f).coerceIn(-1f..1f)
|
||||
}
|
||||
}
|
||||
|
||||
fun encodeWaveFile(file: File, data: ShortArray) {
|
||||
file.outputStream().use {
|
||||
it.write(headerBytes(data.size * 2))
|
||||
val buffer = ByteBuffer.allocate(data.size * 2)
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN)
|
||||
buffer.asShortBuffer().put(data)
|
||||
val bytes = ByteArray(buffer.limit())
|
||||
buffer.get(bytes)
|
||||
it.write(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
private fun headerBytes(totalLength: Int): ByteArray {
|
||||
require(totalLength >= 44)
|
||||
ByteBuffer.allocate(44).apply {
|
||||
order(ByteOrder.LITTLE_ENDIAN)
|
||||
|
||||
put('R'.code.toByte())
|
||||
put('I'.code.toByte())
|
||||
put('F'.code.toByte())
|
||||
put('F'.code.toByte())
|
||||
|
||||
putInt(totalLength - 8)
|
||||
|
||||
put('W'.code.toByte())
|
||||
put('A'.code.toByte())
|
||||
put('V'.code.toByte())
|
||||
put('E'.code.toByte())
|
||||
|
||||
put('f'.code.toByte())
|
||||
put('m'.code.toByte())
|
||||
put('t'.code.toByte())
|
||||
put(' '.code.toByte())
|
||||
|
||||
putInt(16)
|
||||
putShort(1.toShort())
|
||||
putShort(1.toShort())
|
||||
putInt(16000)
|
||||
putInt(32000)
|
||||
putShort(2.toShort())
|
||||
putShort(16.toShort())
|
||||
|
||||
put('d'.code.toByte())
|
||||
put('a'.code.toByte())
|
||||
put('t'.code.toByte())
|
||||
put('a'.code.toByte())
|
||||
|
||||
putInt(totalLength - 44)
|
||||
position(0)
|
||||
}.also {
|
||||
val bytes = ByteArray(it.limit())
|
||||
it.get(bytes)
|
||||
return bytes
|
||||
}
|
||||
}
|
@ -0,0 +1,88 @@
|
||||
package com.whispercppdemo.recorder
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder
|
||||
import com.whispercppdemo.media.encodeWaveFile
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.asCoroutineDispatcher
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
class Recorder {
|
||||
private val scope: CoroutineScope = CoroutineScope(
|
||||
Executors.newSingleThreadExecutor().asCoroutineDispatcher()
|
||||
)
|
||||
private var recorder: AudioRecordThread? = null
|
||||
|
||||
suspend fun startRecording(outputFile: File, onError: (Exception) -> Unit) = withContext(scope.coroutineContext) {
|
||||
recorder = AudioRecordThread(outputFile, onError)
|
||||
recorder?.start()
|
||||
}
|
||||
|
||||
suspend fun stopRecording() = withContext(scope.coroutineContext) {
|
||||
recorder?.stopRecording()
|
||||
@Suppress("BlockingMethodInNonBlockingContext")
|
||||
recorder?.join()
|
||||
recorder = null
|
||||
}
|
||||
}
|
||||
|
||||
private class AudioRecordThread(
|
||||
private val outputFile: File,
|
||||
private val onError: (Exception) -> Unit
|
||||
) :
|
||||
Thread("AudioRecorder") {
|
||||
private var quit = AtomicBoolean(false)
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun run() {
|
||||
try {
|
||||
val bufferSize = AudioRecord.getMinBufferSize(
|
||||
16000,
|
||||
AudioFormat.CHANNEL_IN_MONO,
|
||||
AudioFormat.ENCODING_PCM_16BIT
|
||||
) * 4
|
||||
val buffer = ShortArray(bufferSize / 2)
|
||||
|
||||
val audioRecord = AudioRecord(
|
||||
MediaRecorder.AudioSource.MIC,
|
||||
16000,
|
||||
AudioFormat.CHANNEL_IN_MONO,
|
||||
AudioFormat.ENCODING_PCM_16BIT,
|
||||
bufferSize
|
||||
)
|
||||
|
||||
try {
|
||||
audioRecord.startRecording()
|
||||
|
||||
val allData = mutableListOf<Short>()
|
||||
|
||||
while (!quit.get()) {
|
||||
val read = audioRecord.read(buffer, 0, buffer.size)
|
||||
if (read > 0) {
|
||||
for (i in 0 until read) {
|
||||
allData.add(buffer[i])
|
||||
}
|
||||
} else {
|
||||
throw java.lang.RuntimeException("audioRecord.read returned $read")
|
||||
}
|
||||
}
|
||||
|
||||
audioRecord.stop()
|
||||
encodeWaveFile(outputFile, allData.toShortArray())
|
||||
} finally {
|
||||
audioRecord.release()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
onError(e)
|
||||
}
|
||||
}
|
||||
|
||||
fun stopRecording() {
|
||||
quit.set(true)
|
||||
}
|
||||
}
|
@ -0,0 +1,99 @@
|
||||
package com.whispercppdemo.ui.main
|
||||
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.res.stringResource
|
||||
import androidx.compose.ui.unit.dp
|
||||
import com.google.accompanist.permissions.ExperimentalPermissionsApi
|
||||
import com.google.accompanist.permissions.isGranted
|
||||
import com.google.accompanist.permissions.rememberPermissionState
|
||||
import com.whispercppdemo.R
|
||||
|
||||
@Composable
|
||||
fun MainScreen(viewModel: MainScreenViewModel) {
|
||||
MainScreen(
|
||||
canTranscribe = viewModel.canTranscribe,
|
||||
isRecording = viewModel.isRecording,
|
||||
messageLog = viewModel.dataLog,
|
||||
onTranscribeSampleTapped = viewModel::transcribeSample,
|
||||
onRecordTapped = viewModel::toggleRecord
|
||||
)
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
private fun MainScreen(
|
||||
canTranscribe: Boolean,
|
||||
isRecording: Boolean,
|
||||
messageLog: String,
|
||||
onTranscribeSampleTapped: () -> Unit,
|
||||
onRecordTapped: () -> Unit
|
||||
) {
|
||||
Scaffold(
|
||||
topBar = {
|
||||
TopAppBar(
|
||||
title = { Text(stringResource(R.string.app_name)) }
|
||||
)
|
||||
},
|
||||
) { innerPadding ->
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.padding(innerPadding)
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Row(horizontalArrangement = Arrangement.SpaceBetween) {
|
||||
TranscribeSampleButton(enabled = canTranscribe, onClick = onTranscribeSampleTapped)
|
||||
RecordButton(
|
||||
enabled = canTranscribe,
|
||||
isRecording = isRecording,
|
||||
onClick = onRecordTapped
|
||||
)
|
||||
}
|
||||
MessageLog(messageLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun MessageLog(log: String) {
|
||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun TranscribeSampleButton(enabled: Boolean, onClick: () -> Unit) {
|
||||
Button(onClick = onClick, enabled = enabled) {
|
||||
Text("Transcribe sample")
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(ExperimentalPermissionsApi::class)
|
||||
@Composable
|
||||
private fun RecordButton(enabled: Boolean, isRecording: Boolean, onClick: () -> Unit) {
|
||||
val micPermissionState = rememberPermissionState(
|
||||
permission = android.Manifest.permission.RECORD_AUDIO,
|
||||
onPermissionResult = { granted ->
|
||||
if (granted) {
|
||||
onClick()
|
||||
}
|
||||
}
|
||||
)
|
||||
Button(onClick = {
|
||||
if (micPermissionState.status.isGranted) {
|
||||
onClick()
|
||||
} else {
|
||||
micPermissionState.launchPermissionRequest()
|
||||
}
|
||||
}, enabled = enabled) {
|
||||
Text(
|
||||
if (isRecording) {
|
||||
"Stop recording"
|
||||
} else {
|
||||
"Start recording"
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
@ -0,0 +1,193 @@
|
||||
package com.whispercppdemo.ui.main
|
||||
|
||||
import android.app.Application
|
||||
import android.content.Context
|
||||
import android.media.MediaPlayer
|
||||
import android.util.Log
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.core.net.toUri
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.ViewModelProvider
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import androidx.lifecycle.viewmodel.initializer
|
||||
import androidx.lifecycle.viewmodel.viewModelFactory
|
||||
import com.whispercppdemo.media.decodeWaveFile
|
||||
import com.whispercppdemo.recorder.Recorder
|
||||
import com.whispercppdemo.whisper.WhisperContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
|
||||
private const val LOG_TAG = "MainScreenViewModel"
|
||||
|
||||
class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
var canTranscribe by mutableStateOf(false)
|
||||
private set
|
||||
var dataLog by mutableStateOf("")
|
||||
private set
|
||||
var isRecording by mutableStateOf(false)
|
||||
private set
|
||||
|
||||
private val modelsPath = File(application.filesDir, "models")
|
||||
private val samplesPath = File(application.filesDir, "samples")
|
||||
private var recorder: Recorder = Recorder()
|
||||
private var whisperContext: WhisperContext? = null
|
||||
private var mediaPlayer: MediaPlayer? = null
|
||||
private var recordedFile: File? = null
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
loadData()
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun loadData() {
|
||||
printMessage("Loading data...\n")
|
||||
try {
|
||||
copyAssets()
|
||||
loadBaseModel()
|
||||
canTranscribe = true
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun printMessage(msg: String) = withContext(Dispatchers.Main) {
|
||||
dataLog += msg
|
||||
}
|
||||
|
||||
private suspend fun copyAssets() = withContext(Dispatchers.IO) {
|
||||
modelsPath.mkdirs()
|
||||
samplesPath.mkdirs()
|
||||
application.copyData("models", modelsPath, ::printMessage)
|
||||
application.copyData("samples", samplesPath, ::printMessage)
|
||||
printMessage("All data copied to working directory.\n")
|
||||
}
|
||||
|
||||
private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
|
||||
printMessage("Loading model...\n")
|
||||
val firstModel = modelsPath.listFiles()!!.first()
|
||||
whisperContext = WhisperContext.createContext(firstModel.absolutePath)
|
||||
printMessage("Loaded model ${firstModel.name}.\n")
|
||||
}
|
||||
|
||||
fun transcribeSample() = viewModelScope.launch {
|
||||
transcribeAudio(getFirstSample())
|
||||
}
|
||||
|
||||
private suspend fun getFirstSample(): File = withContext(Dispatchers.IO) {
|
||||
samplesPath.listFiles()!!.first()
|
||||
}
|
||||
|
||||
private suspend fun readAudioSamples(file: File): FloatArray = withContext(Dispatchers.IO) {
|
||||
stopPlayback()
|
||||
startPlayback(file)
|
||||
return@withContext decodeWaveFile(file)
|
||||
}
|
||||
|
||||
private suspend fun stopPlayback() = withContext(Dispatchers.Main) {
|
||||
mediaPlayer?.stop()
|
||||
mediaPlayer?.release()
|
||||
mediaPlayer = null
|
||||
}
|
||||
|
||||
private suspend fun startPlayback(file: File) = withContext(Dispatchers.Main) {
|
||||
mediaPlayer = MediaPlayer.create(application, file.absolutePath.toUri())
|
||||
mediaPlayer?.start()
|
||||
}
|
||||
|
||||
private suspend fun transcribeAudio(file: File) {
|
||||
if (!canTranscribe) {
|
||||
return
|
||||
}
|
||||
|
||||
canTranscribe = false
|
||||
|
||||
try {
|
||||
printMessage("Reading wave samples...\n")
|
||||
val data = readAudioSamples(file)
|
||||
printMessage("Transcribing data...\n")
|
||||
val text = whisperContext?.transcribeData(data)
|
||||
printMessage("Done: $text\n")
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
}
|
||||
|
||||
canTranscribe = true
|
||||
}
|
||||
|
||||
fun toggleRecord() = viewModelScope.launch {
|
||||
try {
|
||||
if (isRecording) {
|
||||
recorder.stopRecording()
|
||||
isRecording = false
|
||||
recordedFile?.let { transcribeAudio(it) }
|
||||
} else {
|
||||
stopPlayback()
|
||||
val file = getTempFileForRecording()
|
||||
recorder.startRecording(file) { e ->
|
||||
viewModelScope.launch {
|
||||
withContext(Dispatchers.Main) {
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
}
|
||||
isRecording = true
|
||||
recordedFile = file
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.w(LOG_TAG, e)
|
||||
printMessage("${e.localizedMessage}\n")
|
||||
isRecording = false
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun getTempFileForRecording() = withContext(Dispatchers.IO) {
|
||||
File.createTempFile("recording", "wav")
|
||||
}
|
||||
|
||||
override fun onCleared() {
|
||||
runBlocking {
|
||||
whisperContext?.release()
|
||||
whisperContext = null
|
||||
stopPlayback()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun factory() = viewModelFactory {
|
||||
initializer {
|
||||
val application =
|
||||
this[ViewModelProvider.AndroidViewModelFactory.APPLICATION_KEY] as Application
|
||||
MainScreenViewModel(application)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun Context.copyData(
|
||||
assetDirName: String,
|
||||
destDir: File,
|
||||
printMessage: suspend (String) -> Unit
|
||||
) = withContext(Dispatchers.IO) {
|
||||
assets.list(assetDirName)?.forEach { name ->
|
||||
val assetPath = "$assetDirName/$name"
|
||||
Log.v(LOG_TAG, "Processing $assetPath...")
|
||||
val destination = File(destDir, name)
|
||||
Log.v(LOG_TAG, "Copying $assetPath to $destination...")
|
||||
printMessage("Copying $name...\n")
|
||||
assets.open(assetPath).use { input ->
|
||||
destination.outputStream().use { output ->
|
||||
input.copyTo(output)
|
||||
}
|
||||
}
|
||||
Log.v(LOG_TAG, "Copied $assetPath to $destination")
|
||||
}
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import androidx.compose.ui.graphics.Color
|
||||
|
||||
val Purple80 = Color(0xFFD0BCFF)
|
||||
val PurpleGrey80 = Color(0xFFCCC2DC)
|
||||
val Pink80 = Color(0xFFEFB8C8)
|
||||
|
||||
val Purple40 = Color(0xFF6650a4)
|
||||
val PurpleGrey40 = Color(0xFF625b71)
|
||||
val Pink40 = Color(0xFF7D5260)
|
@ -0,0 +1,68 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import android.app.Activity
|
||||
import android.os.Build
|
||||
import androidx.compose.foundation.isSystemInDarkTheme
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.darkColorScheme
|
||||
import androidx.compose.material3.dynamicDarkColorScheme
|
||||
import androidx.compose.material3.dynamicLightColorScheme
|
||||
import androidx.compose.material3.lightColorScheme
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.SideEffect
|
||||
import androidx.compose.ui.graphics.toArgb
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalView
|
||||
import androidx.core.view.ViewCompat
|
||||
|
||||
private val DarkColorScheme = darkColorScheme(
|
||||
primary = Purple80,
|
||||
secondary = PurpleGrey80,
|
||||
tertiary = Pink80
|
||||
)
|
||||
|
||||
private val LightColorScheme = lightColorScheme(
|
||||
primary = Purple40,
|
||||
secondary = PurpleGrey40,
|
||||
tertiary = Pink40
|
||||
|
||||
/* Other default colors to override
|
||||
background = Color(0xFFFFFBFE),
|
||||
surface = Color(0xFFFFFBFE),
|
||||
onPrimary = Color.White,
|
||||
onSecondary = Color.White,
|
||||
onTertiary = Color.White,
|
||||
onBackground = Color(0xFF1C1B1F),
|
||||
onSurface = Color(0xFF1C1B1F),
|
||||
*/
|
||||
)
|
||||
|
||||
@Composable
|
||||
fun WhisperCppDemoTheme(
|
||||
darkTheme: Boolean = isSystemInDarkTheme(),
|
||||
// Dynamic color is available on Android 12+
|
||||
dynamicColor: Boolean = true,
|
||||
content: @Composable () -> Unit
|
||||
) {
|
||||
val colorScheme = when {
|
||||
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
|
||||
val context = LocalContext.current
|
||||
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
|
||||
}
|
||||
darkTheme -> DarkColorScheme
|
||||
else -> LightColorScheme
|
||||
}
|
||||
val view = LocalView.current
|
||||
if (!view.isInEditMode) {
|
||||
SideEffect {
|
||||
(view.context as Activity).window.statusBarColor = colorScheme.primary.toArgb()
|
||||
ViewCompat.getWindowInsetsController(view)?.isAppearanceLightStatusBars = darkTheme
|
||||
}
|
||||
}
|
||||
|
||||
MaterialTheme(
|
||||
colorScheme = colorScheme,
|
||||
typography = Typography,
|
||||
content = content
|
||||
)
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package com.whispercppdemo.ui.theme
|
||||
|
||||
import androidx.compose.material3.Typography
|
||||
import androidx.compose.ui.text.TextStyle
|
||||
import androidx.compose.ui.text.font.FontFamily
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.sp
|
||||
|
||||
// Set of Material typography styles to start with
|
||||
val Typography = Typography(
|
||||
bodyLarge = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Normal,
|
||||
fontSize = 16.sp,
|
||||
lineHeight = 24.sp,
|
||||
letterSpacing = 0.5.sp
|
||||
)
|
||||
/* Other default text styles to override
|
||||
titleLarge = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Normal,
|
||||
fontSize = 22.sp,
|
||||
lineHeight = 28.sp,
|
||||
letterSpacing = 0.sp
|
||||
),
|
||||
labelSmall = TextStyle(
|
||||
fontFamily = FontFamily.Default,
|
||||
fontWeight = FontWeight.Medium,
|
||||
fontSize = 11.sp,
|
||||
lineHeight = 16.sp,
|
||||
letterSpacing = 0.5.sp
|
||||
)
|
||||
*/
|
||||
)
|
@ -0,0 +1,61 @@
|
||||
package com.whispercppdemo.whisper
|
||||
|
||||
import kotlinx.coroutines.*
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
class WhisperContext private constructor(private var ptr: Long) {
|
||||
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
|
||||
private val scope: CoroutineScope = CoroutineScope(
|
||||
Executors.newSingleThreadExecutor().asCoroutineDispatcher()
|
||||
)
|
||||
|
||||
suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
|
||||
require(ptr != 0L)
|
||||
WhisperLib.fullTranscribe(ptr, data)
|
||||
val textCount = WhisperLib.getTextSegmentCount(ptr)
|
||||
return@withContext buildString {
|
||||
for (i in 0 until textCount) {
|
||||
append(WhisperLib.getTextSegment(ptr, i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun release() = withContext(scope.coroutineContext) {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr)
|
||||
ptr = 0
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
runBlocking {
|
||||
release()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun createContext(filePath: String): WhisperContext {
|
||||
val ptr = WhisperLib.initContext(filePath)
|
||||
if (ptr == 0L) {
|
||||
throw java.lang.RuntimeException("Couldn't create context with path $filePath")
|
||||
}
|
||||
return WhisperContext(ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class WhisperLib {
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("whisper")
|
||||
}
|
||||
|
||||
// JNI methods
|
||||
external fun initContext(modelPath: String): Long
|
||||
external fun freeContext(contextPtr: Long)
|
||||
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
|
||||
external fun getTextSegmentCount(contextPtr: Long): Int
|
||||
external fun getTextSegment(contextPtr: Long, index: Int): String
|
||||
}
|
||||
}
|
||||
|
22
examples/whisper.android/app/src/main/jni/whisper/Android.mk
Normal file
22
examples/whisper.android/app/src/main/jni/whisper/Android.mk
Normal file
@ -0,0 +1,22 @@
|
||||
LOCAL_PATH := $(call my-dir)
|
||||
include $(CLEAR_VARS)
|
||||
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
|
||||
LOCAL_LDLIBS := -llog
|
||||
LOCAL_MODULE := libwhisper
|
||||
|
||||
# Make the final output library smaller by only keeping the symbols referenced from the app.
|
||||
ifneq ($(APP_OPTIM),debug)
|
||||
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
|
||||
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
|
||||
LOCAL_LDFLAGS += -Wl,--gc-sections
|
||||
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
|
||||
LOCAL_LDFLAGS += -flto
|
||||
endif
|
||||
|
||||
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
|
||||
LOCAL_CPPFLAGS += -std=c++11
|
||||
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
|
||||
$(WHISPER_LIB_DIR)/whisper.cpp \
|
||||
$(LOCAL_PATH)/jni.c
|
||||
|
||||
include $(BUILD_SHARED_LIBRARY)
|
@ -0,0 +1 @@
|
||||
APP_STL := c++_static
|
93
examples/whisper.android/app/src/main/jni/whisper/jni.c
Normal file
93
examples/whisper.android/app/src/main/jni/whisper/jni.c
Normal file
@ -0,0 +1,93 @@
|
||||
#include <jni.h>
|
||||
#include <android/log.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include "whisper.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
|
||||
static inline int min(int a, int b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
static inline int max(int a, int b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
|
||||
JNIEnv *env, jobject thiz, jstring model_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
|
||||
context = whisper_init(model_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
whisper_free(context);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// Leave 2 processors free (i.e. the high-efficiency cores).
|
||||
int max_threads = max(1, min(8, get_nprocs() - 2));
|
||||
LOGI("Selecting %d threads", max_threads);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = max_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
||||
whisper_reset_timings(context);
|
||||
|
||||
LOGI("About to run whisper_full");
|
||||
if (whisper_full(context, params, audio_data_arr, audio_data_length) != 0) {
|
||||
LOGI("Failed to run the model");
|
||||
} else {
|
||||
whisper_print_timings(context);
|
||||
}
|
||||
(*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegmentCount(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
return whisper_full_n_segments(context);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_getTextSegment(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
@ -0,0 +1,170 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
@ -0,0 +1,30 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
10
examples/whisper.android/app/src/main/res/values/colors.xml
Normal file
10
examples/whisper.android/app/src/main/res/values/colors.xml
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">WhisperCppDemo</string>
|
||||
</resources>
|
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
|
||||
<style name="Theme.WhisperCppDemo" parent="android:Theme.Material.Light.NoActionBar" />
|
||||
</resources>
|
@ -0,0 +1,13 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample backup rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/guide/topics/data/autobackup
|
||||
for details.
|
||||
Note: This file is ignored for devices older that API 31
|
||||
See https://developer.android.com/about/versions/12/backup-restore
|
||||
-->
|
||||
<full-backup-content>
|
||||
<!--
|
||||
<include domain="sharedpref" path="."/>
|
||||
<exclude domain="sharedpref" path="device.xml"/>
|
||||
-->
|
||||
</full-backup-content>
|
@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample data extraction rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
|
||||
for details.
|
||||
-->
|
||||
<data-extraction-rules>
|
||||
<cloud-backup>
|
||||
<!-- TODO: Use <include> and <exclude> to control what is backed up.
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
-->
|
||||
</cloud-backup>
|
||||
<!--
|
||||
<device-transfer>
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
</device-transfer>
|
||||
-->
|
||||
</data-extraction-rules>
|
@ -0,0 +1,17 @@
|
||||
package com.whispercppdemo
|
||||
|
||||
import org.junit.Test
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
class ExampleUnitTest {
|
||||
@Test
|
||||
fun addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2)
|
||||
}
|
||||
}
|
6
examples/whisper.android/build.gradle
Normal file
6
examples/whisper.android/build.gradle
Normal file
@ -0,0 +1,6 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
plugins {
|
||||
id 'com.android.application' version '7.3.1' apply false
|
||||
id 'com.android.library' version '7.3.1' apply false
|
||||
id 'org.jetbrains.kotlin.android' version '1.7.10' apply false
|
||||
}
|
23
examples/whisper.android/gradle.properties
Normal file
23
examples/whisper.android/gradle.properties
Normal file
@ -0,0 +1,23 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app's APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Kotlin code style for this project: "official" or "obsolete":
|
||||
kotlin.code.style=official
|
||||
# Enables namespacing of each library's R class so that its R class includes only the
|
||||
# resources declared in the library itself and none from the library's dependencies,
|
||||
# thereby reducing the size of the R class for that library
|
||||
android.nonTransitiveRClass=true
|
BIN
examples/whisper.android/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
BIN
examples/whisper.android/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
Binary file not shown.
6
examples/whisper.android/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
6
examples/whisper.android/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
#Wed Dec 14 10:37:24 EST 2022
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
|
||||
distributionPath=wrapper/dists
|
||||
zipStorePath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
185
examples/whisper.android/gradlew
vendored
Executable file
185
examples/whisper.android/gradlew
vendored
Executable file
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
#
|
||||
# Copyright 2015 the original author or authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=`expr $i + 1`
|
||||
done
|
||||
case $i in
|
||||
0) set -- ;;
|
||||
1) set -- "$args0" ;;
|
||||
2) set -- "$args0" "$args1" ;;
|
||||
3) set -- "$args0" "$args1" "$args2" ;;
|
||||
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=`save "$@"`
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
exec "$JAVACMD" "$@"
|
89
examples/whisper.android/gradlew.bat
vendored
Normal file
89
examples/whisper.android/gradlew.bat
vendored
Normal file
@ -0,0 +1,89 @@
|
||||
@rem
|
||||
@rem Copyright 2015 the original author or authors.
|
||||
@rem
|
||||
@rem Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@rem you may not use this file except in compliance with the License.
|
||||
@rem You may obtain a copy of the License at
|
||||
@rem
|
||||
@rem https://www.apache.org/licenses/LICENSE-2.0
|
||||
@rem
|
||||
@rem Unless required by applicable law or agreed to in writing, software
|
||||
@rem distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@rem See the License for the specific language governing permissions and
|
||||
@rem limitations under the License.
|
||||
@rem
|
||||
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
|
||||
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
10
examples/whisper.android/local.properties
Normal file
10
examples/whisper.android/local.properties
Normal file
@ -0,0 +1,10 @@
|
||||
## This file is automatically generated by Android Studio.
|
||||
# Do not modify this file -- YOUR CHANGES WILL BE ERASED!
|
||||
#
|
||||
# This file should *NOT* be checked into Version Control Systems,
|
||||
# as it contains information specific to your local configuration.
|
||||
#
|
||||
# Location of the SDK. This is only used by Gradle.
|
||||
# For customization when using a Version Control System, please read the
|
||||
# header note.
|
||||
sdk.dir=/Users/kevin/Library/Android/sdk
|
16
examples/whisper.android/settings.gradle
Normal file
16
examples/whisper.android/settings.gradle
Normal file
@ -0,0 +1,16 @@
|
||||
pluginManagement {
|
||||
repositories {
|
||||
gradlePluginPortal()
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
dependencyResolutionManagement {
|
||||
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||
repositories {
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
rootProject.name = "WhisperCppDemo"
|
||||
include ':app'
|
176
ggml.c
176
ggml.c
@ -14,6 +14,12 @@
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
// if C99 - static_assert is noop
|
||||
// ref: https://stackoverflow.com/a/53923785/4039976
|
||||
#ifndef static_assert
|
||||
#define static_assert(cond, msg) struct global_scope_noop_trick
|
||||
#endif
|
||||
|
||||
#if defined _MSC_VER || defined(__MINGW32__)
|
||||
|
||||
#if !defined(__MINGW32__)
|
||||
@ -135,9 +141,6 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
// FP16 <-> FP32
|
||||
// ref: https://github.com/Maratyszcza/FP16
|
||||
|
||||
#ifdef __F16C__
|
||||
float ggml_fp16_to_fp32(ggml_fp16_t h) {
|
||||
return _cvtsh_ss(h);
|
||||
@ -151,6 +154,9 @@ ggml_fp16_t ggml_fp32_to_fp16(float f) {
|
||||
|
||||
#else
|
||||
|
||||
// FP16 <-> FP32
|
||||
// ref: https://github.com/Maratyszcza/FP16
|
||||
|
||||
static inline float fp32_from_bits(uint32_t w) {
|
||||
union {
|
||||
uint32_t as_bits;
|
||||
@ -434,10 +440,10 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
||||
y2 = _mm256_loadu_ps(y + i + 16);
|
||||
y3 = _mm256_loadu_ps(y + i + 24);
|
||||
|
||||
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
|
||||
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
|
||||
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
|
||||
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
|
||||
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
|
||||
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
|
||||
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
|
||||
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
|
||||
}
|
||||
|
||||
sum0 = _mm256_add_ps(sum0, sum1);
|
||||
@ -675,10 +681,10 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
||||
y2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
|
||||
y3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
|
||||
|
||||
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
|
||||
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
|
||||
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
|
||||
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
|
||||
sum0 = _mm256_add_ps(_mm256_mul_ps(x0, y0), sum0);
|
||||
sum1 = _mm256_add_ps(_mm256_mul_ps(x1, y1), sum1);
|
||||
sum2 = _mm256_add_ps(_mm256_mul_ps(x2, y2), sum2);
|
||||
sum3 = _mm256_add_ps(_mm256_mul_ps(x3, y3), sum3);
|
||||
}
|
||||
|
||||
const __m256 sum01 = _mm256_add_ps(sum0, sum1);
|
||||
@ -844,10 +850,10 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
|
||||
y2 = _mm256_loadu_ps(y + i + 16);
|
||||
y3 = _mm256_loadu_ps(y + i + 24);
|
||||
|
||||
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
|
||||
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
|
||||
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
|
||||
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
|
||||
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v4), y0);
|
||||
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v4), y1);
|
||||
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v4), y2);
|
||||
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v4), y3);
|
||||
|
||||
_mm256_storeu_ps(y + i + 0, y0);
|
||||
_mm256_storeu_ps(y + i + 8, y1);
|
||||
@ -1041,10 +1047,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
|
||||
x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
|
||||
x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
|
||||
|
||||
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
|
||||
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
|
||||
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
|
||||
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
|
||||
y0 = _mm256_add_ps(_mm256_mul_ps(x0, v8), y0);
|
||||
y1 = _mm256_add_ps(_mm256_mul_ps(x1, v8), y1);
|
||||
y2 = _mm256_add_ps(_mm256_mul_ps(x2, v8), y2);
|
||||
y3 = _mm256_add_ps(_mm256_mul_ps(x3, v8), y3);
|
||||
|
||||
_mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0));
|
||||
_mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0));
|
||||
@ -1112,7 +1118,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
||||
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
||||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
#if defined(__AVX__) || defined(__AVX2__)
|
||||
// AVX 256-bit
|
||||
const int n32 = (n & ~31);
|
||||
|
||||
const __m256 v4 = _mm256_set1_ps(v);
|
||||
|
||||
__m256 y0, y1, y2, y3;
|
||||
|
||||
for (int i = 0; i < n32; i += 32) {
|
||||
y0 = _mm256_loadu_ps(y + i + 0);
|
||||
y1 = _mm256_loadu_ps(y + i + 8);
|
||||
y2 = _mm256_loadu_ps(y + i + 16);
|
||||
y3 = _mm256_loadu_ps(y + i + 24);
|
||||
|
||||
y0 = _mm256_mul_ps(y0, v4);
|
||||
y1 = _mm256_mul_ps(y1, v4);
|
||||
y2 = _mm256_mul_ps(y2, v4);
|
||||
y3 = _mm256_mul_ps(y3, v4);
|
||||
|
||||
_mm256_storeu_ps(y + i + 0, y0);
|
||||
_mm256_storeu_ps(y + i + 8, y1);
|
||||
_mm256_storeu_ps(y + i + 16, y2);
|
||||
_mm256_storeu_ps(y + i + 24, y3);
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = n32; i < n; ++i) {
|
||||
y[i] *= v;
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] *= v;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s); }
|
||||
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
|
||||
@ -3172,22 +3216,96 @@ void ggml_compute_forward_dup_f16(
|
||||
return;
|
||||
}
|
||||
|
||||
//const int ne00 = src0->ne[0];
|
||||
//const int ne01 = src0->ne[1];
|
||||
//const int ne02 = src0->ne[2];
|
||||
//const int ne03 = src0->ne[3];
|
||||
const int ne00 = src0->ne[0];
|
||||
const int ne01 = src0->ne[1];
|
||||
const int ne02 = src0->ne[2];
|
||||
const int ne03 = src0->ne[3];
|
||||
|
||||
//const size_t nb00 = src0->nb[0];
|
||||
//const size_t nb01 = src0->nb[1];
|
||||
//const size_t nb02 = src0->nb[2];
|
||||
//const size_t nb03 = src0->nb[3];
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
const size_t nb02 = src0->nb[2];
|
||||
const size_t nb03 = src0->nb[3];
|
||||
|
||||
if (ggml_is_contiguous(src0) && src0->type == dst->type) {
|
||||
memcpy(dst->data, src0->data, ggml_nelements(dst) * GGML_TYPE_SIZE[src0->type]);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
if (src0->nb[0] == sizeof(ggml_fp16_t)) {
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
int id = 0;
|
||||
const size_t rs = ne00*nb00;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * dst_ptr = (char *) dst->data + id*rs;
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, rs);
|
||||
|
||||
id++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F32) {
|
||||
int id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
}
|
||||
} else {
|
||||
//printf("%s: this is not optimal - fix me\n", __func__);
|
||||
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
int id = 0;
|
||||
float * dst_ptr = (float *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
|
||||
id++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_F16) {
|
||||
int id = 0;
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
dst_ptr[id] = *src0_ptr;
|
||||
id++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_dup_f32(
|
||||
|
46
ggml.h
46
ggml.h
@ -681,34 +681,32 @@ struct ggml_opt_params {
|
||||
bool print_forward_graph;
|
||||
bool print_backward_graph;
|
||||
|
||||
union {
|
||||
// ADAM parameters
|
||||
struct {
|
||||
int n_iter;
|
||||
// ADAM parameters
|
||||
struct {
|
||||
int n_iter;
|
||||
|
||||
float alpha; // learning rate
|
||||
float beta1;
|
||||
float beta2;
|
||||
float eps; // epsilon for numerical stability
|
||||
float eps_f; // epsilon for convergence test
|
||||
float eps_g; // epsilon for convergence test
|
||||
} adam;
|
||||
float alpha; // learning rate
|
||||
float beta1;
|
||||
float beta2;
|
||||
float eps; // epsilon for numerical stability
|
||||
float eps_f; // epsilon for convergence test
|
||||
float eps_g; // epsilon for convergence test
|
||||
} adam;
|
||||
|
||||
// LBFGS parameters
|
||||
struct {
|
||||
int m; // number of corrections to approximate the inv. Hessian
|
||||
int n_iter;
|
||||
int max_linesearch;
|
||||
// LBFGS parameters
|
||||
struct {
|
||||
int m; // number of corrections to approximate the inv. Hessian
|
||||
int n_iter;
|
||||
int max_linesearch;
|
||||
|
||||
float eps; // convergence tolerance
|
||||
float ftol; // line search tolerance
|
||||
float wolfe;
|
||||
float min_step;
|
||||
float max_step;
|
||||
float eps; // convergence tolerance
|
||||
float ftol; // line search tolerance
|
||||
float wolfe;
|
||||
float min_step;
|
||||
float max_step;
|
||||
|
||||
enum ggml_linesearch linesearch;
|
||||
} lbfgs;
|
||||
};
|
||||
enum ggml_linesearch linesearch;
|
||||
} lbfgs;
|
||||
};
|
||||
|
||||
struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
|
||||
|
@ -40,7 +40,7 @@ if exist "ggml-%model%.bin" (
|
||||
goto :eof
|
||||
)
|
||||
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://ggml.ggerganov.com/ggml-model-whisper-%model%.bin -OutFile ggml-%model%.bin"
|
||||
PowerShell -NoProfile -ExecutionPolicy Bypass -Command "Invoke-WebRequest -Uri https://huggingface.co/datasets/ggerganov/whisper.cpp/raw/main/ggml-%model%.bin -OutFile ggml-%model%.bin"
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo Failed to download ggml model %model%
|
||||
|
244
whisper.cpp
244
whisper.cpp
@ -14,6 +14,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
#define USE_FLASH_ATTN
|
||||
//#define USE_FLASH_FF
|
||||
@ -549,13 +550,20 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
|
||||
//}
|
||||
|
||||
std::string word;
|
||||
std::vector<char> tmp;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
read_safe(fin, len);
|
||||
|
||||
std::vector<char> tmp(len); // create a buffer
|
||||
fin.read( &tmp[0], tmp.size() ); // read to buffer
|
||||
word.assign(&tmp[0], tmp.size());
|
||||
if (len > 0) {
|
||||
tmp.resize(len);
|
||||
fin.read(&tmp[0], tmp.size()); // read to buffer
|
||||
word.assign(&tmp[0], tmp.size());
|
||||
} else {
|
||||
// seems like we have an empty-string token in multi-language models (i = 50256)
|
||||
//fprintf(stderr, "%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
|
||||
word = "";
|
||||
}
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
@ -1097,7 +1105,7 @@ static bool whisper_encode(
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = wctx.buf_compute.size();
|
||||
params.mem_buffer = wctx.buf_compute.data();
|
||||
params.mem_buffer = wctx.buf_compute.data();
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
@ -2154,6 +2162,71 @@ static bool log_mel_spectrogram(
|
||||
return true;
|
||||
}
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<whisper_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@ -2284,8 +2357,38 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
|
||||
return res;
|
||||
}
|
||||
|
||||
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
||||
const auto res = tokenize(ctx->vocab, text);
|
||||
|
||||
if (res.size() > n_max_tokens) {
|
||||
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < res.size(); i++) {
|
||||
tokens[i] = res[i];
|
||||
}
|
||||
|
||||
return res.size();
|
||||
}
|
||||
|
||||
int whisper_lang_max_id() {
|
||||
auto max_id = 0;
|
||||
for (const auto & kv : g_lang) {
|
||||
max_id = std::max(max_id, kv.second.first);
|
||||
}
|
||||
|
||||
return max_id;
|
||||
}
|
||||
|
||||
int whisper_lang_id(const char * lang) {
|
||||
if (!g_lang.count(lang)) {
|
||||
for (const auto & kv : g_lang) {
|
||||
if (kv.second.second == lang) {
|
||||
return kv.second.first;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
||||
return -1;
|
||||
}
|
||||
@ -2293,6 +2396,86 @@ int whisper_lang_id(const char * lang) {
|
||||
return g_lang.at(lang).first;
|
||||
}
|
||||
|
||||
const char * whisper_lang_str(int id) {
|
||||
for (const auto & kv : g_lang) {
|
||||
if (kv.second.first == id) {
|
||||
return kv.first.c_str();
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: unknown language id %d\n", __func__, id);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
int whisper_lang_auto_detect(
|
||||
struct whisper_context * ctx,
|
||||
int offset_ms,
|
||||
int n_threads,
|
||||
float * lang_probs) {
|
||||
const int seek = offset_ms/10;
|
||||
|
||||
if (seek < 0) {
|
||||
fprintf(stderr, "%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (seek >= ctx->mel.n_len) {
|
||||
fprintf(stderr, "%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, ctx->mel.n_len*10);
|
||||
return -2;
|
||||
}
|
||||
|
||||
// run the encoder
|
||||
if (whisper_encode(ctx, seek, n_threads) != 0) {
|
||||
fprintf(stderr, "%s: failed to encode\n", __func__);
|
||||
return -6;
|
||||
}
|
||||
|
||||
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
|
||||
|
||||
if (whisper_decode(ctx, prompt.data(), prompt.size(), 0, n_threads) != 0) {
|
||||
fprintf(stderr, "%s: failed to decode\n", __func__);
|
||||
return -7;
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, int>> probs_id;
|
||||
for (const auto kv : g_lang) {
|
||||
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
|
||||
probs_id.push_back({ ctx->probs[token_lang], kv.second.first });
|
||||
}
|
||||
|
||||
// sort descending
|
||||
{
|
||||
using pair_type = decltype(probs_id)::value_type;
|
||||
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
}
|
||||
|
||||
// softmax
|
||||
{
|
||||
float sum = 0;
|
||||
for (const auto & kv : probs_id) {
|
||||
sum += exp(kv.first);
|
||||
}
|
||||
|
||||
for (auto & kv : probs_id) {
|
||||
kv.first = exp(kv.first) / sum;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
for (int i = 0; i < probs_id.size(); i++) {
|
||||
if (lang_probs) {
|
||||
lang_probs[probs_id[i].second] = probs_id[i].first;
|
||||
}
|
||||
|
||||
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
|
||||
}
|
||||
}
|
||||
|
||||
return probs_id[0].second;
|
||||
}
|
||||
|
||||
int whisper_n_len(struct whisper_context * ctx) {
|
||||
return ctx->mel.n_len;
|
||||
}
|
||||
@ -2341,6 +2524,10 @@ whisper_token whisper_token_beg(struct whisper_context * ctx) {
|
||||
return ctx->vocab.token_beg;
|
||||
}
|
||||
|
||||
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
|
||||
return whisper_token_sot(ctx) + 1 + lang_id;
|
||||
}
|
||||
|
||||
whisper_token whisper_token_translate(void) {
|
||||
return whisper_vocab::token_translate;
|
||||
}
|
||||
@ -2573,10 +2760,25 @@ int whisper_full(
|
||||
} else {
|
||||
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
|
||||
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
|
||||
return -1;
|
||||
return -2;
|
||||
}
|
||||
}
|
||||
|
||||
// auto-detect language if not specified
|
||||
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
||||
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
||||
|
||||
const auto lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, probs.data());
|
||||
if (lang_id < 0) {
|
||||
fprintf(stderr, "%s: failed to auto-detect language\n", __func__);
|
||||
return -3;
|
||||
}
|
||||
|
||||
params.language = whisper_lang_str(lang_id);
|
||||
|
||||
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
||||
}
|
||||
|
||||
if (params.token_timestamps) {
|
||||
ctx->t_beg = 0;
|
||||
ctx->t_last = 0;
|
||||
@ -2615,7 +2817,8 @@ int whisper_full(
|
||||
// these tokens determine the task that will be performed
|
||||
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
||||
if (whisper_is_multilingual(ctx)) {
|
||||
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
|
||||
const int lang_id = whisper_lang_id(params.language);
|
||||
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
|
||||
if (params.translate) {
|
||||
prompt_init.push_back(whisper_token_translate());
|
||||
} else {
|
||||
@ -2643,10 +2846,17 @@ int whisper_full(
|
||||
}
|
||||
}
|
||||
|
||||
// of only 1 second left, then stop
|
||||
if (seek + 100 >= seek_end) {
|
||||
break;
|
||||
}
|
||||
|
||||
// if there is a very short audio segment left to process, we remove any past prompt since it tends
|
||||
// to confuse the decoder and often make it repeat or hallucinate stuff
|
||||
if (seek > seek_start && seek + 500 >= seek_end) {
|
||||
prompt_past.clear();
|
||||
}
|
||||
|
||||
if (params.encoder_begin_callback) {
|
||||
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
|
||||
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
|
||||
@ -2657,7 +2867,7 @@ int whisper_full(
|
||||
// encode audio features starting at offset seek
|
||||
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
|
||||
fprintf(stderr, "%s: failed to encode\n", __func__);
|
||||
return 7;
|
||||
return -4;
|
||||
}
|
||||
|
||||
int n_past = 0;
|
||||
@ -2695,7 +2905,7 @@ int whisper_full(
|
||||
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
|
||||
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
|
||||
fprintf(stderr, "%s: failed to decode\n", __func__);
|
||||
return 8;
|
||||
return -5;
|
||||
}
|
||||
|
||||
n_past += prompt.size();
|
||||
@ -2731,13 +2941,13 @@ int whisper_full(
|
||||
|
||||
//{
|
||||
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
|
||||
// printf("%s: %10s %6d %6.3f '%s'\n", __func__, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
||||
// printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str());
|
||||
//}
|
||||
|
||||
// end of segment
|
||||
if (token.id == whisper_token_eot(ctx) || // end of text token
|
||||
(params.max_tokens > 0 && i > params.max_tokens) || // max tokens per segment reached
|
||||
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
||||
if (token.id == whisper_token_eot(ctx) || // end of text token
|
||||
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
|
||||
(has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached
|
||||
) {
|
||||
if (result_len == 0) {
|
||||
if (seek + seek_delta + 100 >= seek_end) {
|
||||
@ -2773,8 +2983,14 @@ int whisper_full(
|
||||
}
|
||||
|
||||
if (failed) {
|
||||
fprintf(stderr, "\n%s: failed to generate timestamp token - using fallback strategy\n\n", __func__);
|
||||
seek += 100;
|
||||
// when we fail to sample timestamp token, retry by clearing the past prompt
|
||||
// if it fails again, then we advance the window by 1 second
|
||||
if (prompt_past.size() > 0) {
|
||||
prompt_past.clear();
|
||||
} else {
|
||||
fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__);
|
||||
seek += 100;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
34
whisper.h
34
whisper.h
@ -139,9 +139,41 @@ extern "C" {
|
||||
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns -1 on failure
|
||||
// TODO: not sure if correct
|
||||
WHISPER_API int whisper_tokenize(
|
||||
struct whisper_context * ctx,
|
||||
const char * text,
|
||||
whisper_token * tokens,
|
||||
int n_max_tokens);
|
||||
|
||||
// Largest language id (i.e. number of available languages - 1)
|
||||
WHISPER_API int whisper_lang_max_id();
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
// Examples:
|
||||
// "de" -> 2
|
||||
// "german" -> 2
|
||||
WHISPER_API int whisper_lang_id(const char * lang);
|
||||
|
||||
// Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found
|
||||
WHISPER_API const char * whisper_lang_str(int id);
|
||||
|
||||
// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
|
||||
// Returns the top language id or negative on failure
|
||||
// If not null, fills the lang_probs array with the probabilities of all languages
|
||||
// The array must be whispe_lang_max_id() + 1 in size
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
|
||||
WHISPER_API int whisper_lang_auto_detect(
|
||||
struct whisper_context * ctx,
|
||||
int offset_ms,
|
||||
int n_threads,
|
||||
float * lang_probs);
|
||||
|
||||
WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length
|
||||
WHISPER_API int whisper_n_vocab (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx);
|
||||
@ -160,6 +192,7 @@ extern "C" {
|
||||
WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
|
||||
|
||||
// Task tokens
|
||||
WHISPER_API whisper_token whisper_token_translate (void);
|
||||
@ -225,6 +258,7 @@ extern "C" {
|
||||
const whisper_token * prompt_tokens;
|
||||
int prompt_n_tokens;
|
||||
|
||||
// for auto-detection, set to nullptr, "" or "auto"
|
||||
const char * language;
|
||||
|
||||
struct {
|
||||
|
Reference in New Issue
Block a user