Compare commits
23 Commits
quantize-e
...
parallel-s
Author | SHA1 | Date | |
---|---|---|---|
a2f3b82db3 | |||
76c8b5235b | |||
d029784fb0 | |||
40c66036b6 | |||
fc8565d0e2 | |||
b618229340 | |||
b27726da93 | |||
0867e696a7 | |||
66bb2e9401 | |||
3bfc43e3e3 | |||
f53e1388f5 | |||
933c5bef97 | |||
c99e290a7f | |||
728e1785f0 | |||
d6dad64fbf | |||
a54d8c9dec | |||
0ab5025316 | |||
3f5c1b7ee0 | |||
12030358ee | |||
dcf9511dbb | |||
3dfbe64911 | |||
7e01486b61 | |||
659757329d |
26
.github/workflows/build.yml
vendored
@ -396,32 +396,6 @@ jobs:
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
||||
android_java:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set up JDK 11
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: gradle
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
with:
|
||||
api-level: 30
|
||||
build-tools-version: 30.0.3
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android.java
|
||||
chmod +x ./gradlew
|
||||
./gradlew assembleRelease
|
||||
|
||||
java:
|
||||
needs: [ 'windows' ]
|
||||
runs-on: windows-latest
|
||||
|
4
.gitignore
vendored
@ -54,7 +54,3 @@ bindings/java/.idea/
|
||||
.idea/
|
||||
|
||||
benchmark_results.csv
|
||||
cmake-build-debug/
|
||||
.cxx/
|
||||
.gradle/
|
||||
local.properties
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(whisper.cpp VERSION 1.5.0)
|
||||
project(whisper.cpp VERSION 1.4.3)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
8
Makefile
@ -362,8 +362,8 @@ quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
@ -418,9 +418,9 @@ samples:
|
||||
.PHONY: medium
|
||||
.PHONY: large-v1
|
||||
.PHONY: large-v2
|
||||
.PHONY: large-v3
|
||||
.PHONY: large
|
||||
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3: main
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large: main
|
||||
bash ./models/download-ggml-model.sh $@
|
||||
@echo ""
|
||||
@echo "==============================================="
|
||||
|
34
README.md
@ -6,7 +6,7 @@
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.5.0](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Beta: [v1.4.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.3) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
@ -16,10 +16,12 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Support for CPU-only inference
|
||||
- [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
@ -34,8 +36,10 @@ Supported platforms:
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
|
||||
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
||||
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
|
||||
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
|
||||
|
||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||
@ -231,18 +235,18 @@ make medium.en
|
||||
make medium
|
||||
make large-v1
|
||||
make large-v2
|
||||
make large-v3
|
||||
make large
|
||||
```
|
||||
|
||||
## Memory usage
|
||||
|
||||
| Model | Disk | Mem |
|
||||
| --- | --- | --- |
|
||||
| tiny | 75 MiB | ~273 MB |
|
||||
| base | 142 MiB | ~388 MB |
|
||||
| small | 466 MiB | ~852 MB |
|
||||
| medium | 1.5 GiB | ~2.1 GB |
|
||||
| large | 2.9 GiB | ~3.9 GB |
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~3.3 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
|
||||
|
||||
## Quantization
|
||||
|
||||
@ -396,12 +400,12 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
|
||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||
cached for the next run.
|
||||
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
|
||||
## NVIDIA GPU support
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
|
||||
With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels.
|
||||
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
|
@ -24,7 +24,7 @@ const (
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3"}
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -9,7 +9,6 @@ archivesBaseName = 'whispercpp'
|
||||
group = 'io.github.ggerganov'
|
||||
version = '1.4.0'
|
||||
|
||||
|
||||
sourceCompatibility = 1.8
|
||||
targetCompatibility = 1.8
|
||||
|
||||
|
@ -2,7 +2,6 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -10,8 +9,6 @@ import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
|
||||
@ -163,28 +160,6 @@ public class WhisperCpp implements AutoCloseable {
|
||||
|
||||
return str.toString().trim();
|
||||
}
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
List<WhisperSegment> segments= new ArrayList<>(nSegments);
|
||||
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
long t1 = lib.whisper_full_get_segment_t1(ctx, i);
|
||||
segments.add(new WhisperSegment(t0,t1,text));
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
// public int getTextSegmentCount(Pointer ctx) {
|
||||
// return lib.whisper_full_n_segments(ctx);
|
||||
|
@ -1,47 +0,0 @@
|
||||
package io.github.ggerganov.whispercpp.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "[" + start + " --> " + end + "]:" + sentence;
|
||||
}
|
||||
}
|
@ -58,9 +58,6 @@ public class WhisperFullParams extends Structure {
|
||||
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Generate timestamps or not? */
|
||||
public CBool no_timestamps;
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public CBool single_segment;
|
||||
|
||||
@ -307,16 +304,10 @@ public class WhisperFullParams extends Structure {
|
||||
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
/** Grammar stuff */
|
||||
public Pointer grammar_rules;
|
||||
public long n_grammar_rules;
|
||||
public long i_start_rule;
|
||||
public float grammar_penalty;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment", "no_timestamps",
|
||||
"no_context", "single_segment",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
@ -325,7 +316,6 @@ public class WhisperFullParams extends Structure {
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||
"logits_filter_callback", "logits_filter_callback_user_data",
|
||||
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
|
||||
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.CBool;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -12,7 +11,6 @@ import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.util.List;
|
||||
|
||||
class WhisperCppTest {
|
||||
private static WhisperCpp whisper = new WhisperCpp();
|
||||
@ -22,12 +20,11 @@ class WhisperCppTest {
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
//String modelName = "../../models/ggml-tiny.bin";
|
||||
String modelName = "../../models/ggml-tiny.en.bin";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
@ -45,7 +42,7 @@ class WhisperCppTest {
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertFalse(params.translate);
|
||||
assertEquals(0.01f, params.thold_pt);
|
||||
assertEquals(5, params.beam_search.beam_size);
|
||||
assertEquals(2, params.beam_search.beam_size);
|
||||
assertEquals(-1.0f, params.beam_search.patience);
|
||||
}
|
||||
|
||||
@ -58,7 +55,7 @@ class WhisperCppTest {
|
||||
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
||||
assertNotEquals(0, params.n_threads);
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertEquals(5, params.greedy.best_of);
|
||||
assertEquals(2, params.greedy.best_of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -75,11 +72,11 @@ class WhisperCppTest {
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
|
||||
try {
|
||||
@ -102,43 +99,4 @@ class WhisperCppTest {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFullTranscribeWithTime() throws Exception {
|
||||
if (!modelInitialised) {
|
||||
System.out.println("Model not initialised, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
// Given
|
||||
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
|
||||
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
|
||||
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
|
||||
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
|
||||
floats[j] = intSample / 32767.0f;
|
||||
}
|
||||
|
||||
List<WhisperSegment> segments = whisper.fullTranscribeWithTime(params, floats);
|
||||
assertTrue(segments.size() > 0, "The size of segments should be greater than 0");
|
||||
for (WhisperSegment segment : segments) {
|
||||
System.out.println(segment);
|
||||
}
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.5.0",
|
||||
"version": "1.4.3",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
@ -23,7 +23,6 @@ add_library(${TARGET} STATIC
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
grammar-parser.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,30 +104,20 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@ -22,11 +21,6 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -36,12 +30,8 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@ -55,8 +45,6 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -87,9 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -124,30 +109,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
@ -155,41 +126,19 @@ std::string transcribe(
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
||||
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
|
||||
} else {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
}
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
@ -198,17 +147,19 @@ std::string transcribe(
|
||||
|
||||
result += text;
|
||||
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
if(token.plog > 0.0f) exit(0);
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
@ -299,7 +250,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -467,9 +418,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
int n_tokens = 0;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
@ -507,7 +456,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -543,27 +492,18 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -596,11 +536,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min0);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
@ -621,30 +559,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= txt.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
@ -657,16 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
@ -734,36 +654,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
@ -9,11 +9,6 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
|
||||
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
|
||||
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
|
||||
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
|
||||
{"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
|
||||
{"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
|
||||
{"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
|
||||
{"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
|
||||
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
|
||||
};
|
||||
|
||||
void ggml_print_ftypes(FILE * fp) {
|
||||
@ -53,15 +48,15 @@ bool ggml_common_quantize_0(
|
||||
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
|
||||
case GGML_FTYPE_UNKNOWN:
|
||||
case GGML_FTYPE_ALL_F32:
|
||||
case GGML_FTYPE_MOSTLY_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q2_K:
|
||||
case GGML_FTYPE_MOSTLY_Q3_K:
|
||||
case GGML_FTYPE_MOSTLY_Q4_K:
|
||||
case GGML_FTYPE_MOSTLY_Q5_K:
|
||||
case GGML_FTYPE_MOSTLY_Q6_K:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
|
||||
return false;
|
||||
@ -172,17 +167,24 @@ bool ggml_common_quantize_0(
|
||||
|
||||
switch ((ggml_type) ttype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
@ -190,6 +192,11 @@ bool ggml_common_quantize_0(
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
|
@ -1,423 +0,0 @@
|
||||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from whisper.cpp
|
||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<whisper_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
||||
? WHISPER_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<whisper_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<whisper_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_char_element(whisper_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_CHAR: return true;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
whisper_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (auto kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by whisper.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
@ -48,7 +48,7 @@ if [ -n "$3" ]; then
|
||||
fi
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
@ -62,8 +62,8 @@ struct whisper_params {
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
|
@ -162,7 +162,6 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
|
||||
"encoder.conv2.bias",
|
||||
"encoder.positional_embedding",
|
||||
"decoder.positional_embedding",
|
||||
"decoder.*",
|
||||
};
|
||||
|
||||
if (!ggml_common_quantize_0(finp, fout, ftype, { ".*" }, to_skip)) {
|
||||
|
@ -53,7 +53,6 @@ struct whisper_params {
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
int32_t n_gpu_layers = 999;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
@ -91,7 +90,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
@ -136,7 +134,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
@ -271,8 +268,6 @@ int main(int argc, char ** argv) {
|
||||
auto lmparams = llama_model_default_params();
|
||||
if (!params.use_gpu) {
|
||||
lmparams.n_gpu_layers = 0;
|
||||
} else {
|
||||
lmparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
|
||||
@ -686,8 +681,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "'", "'\"'\"'");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + text_to_speak + "'").c_str());
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ help()
|
||||
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
|
||||
echo "options:"
|
||||
echo "-s Step in seconds (default is $step)."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large-v3' (default is '$model')."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large' (default is '$model')."
|
||||
echo "-t Number of threads to use."
|
||||
echo "-h Print this help page."
|
||||
echo
|
||||
|
15
examples/whisper.android.java/.gitignore
vendored
@ -1,15 +0,0 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
@ -1,20 +0,0 @@
|
||||
A sample Android app using java code and [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Modify the modelFilePath in the WhisperService.java
|
||||
6. Modify the sampleFilePath in the WhisperService.java
|
||||
7. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
PS:
|
||||
1. Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.
|
||||
2. The cpp code is compiled during the build process
|
||||
3. If you want to import a compiled cpp project in your Android project, please refer to the https://github.com/litongjava/whisper.cpp.android.java.demo
|
||||
|
||||

|
||||
|
Before Width: | Height: | Size: 67 KiB |
1
examples/whisper.android.java/app/.gitignore
vendored
@ -1 +0,0 @@
|
||||
/build
|
@ -1,58 +0,0 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
}
|
||||
|
||||
android {
|
||||
compileSdkVersion 30
|
||||
buildToolsVersion '30.0.3'
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.litongjava.whisper.android.java"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 30
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
cppFlags ""
|
||||
}
|
||||
}
|
||||
ndk {
|
||||
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "src/main/jni/whisper/CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
ndkVersion "25.2.9519653"
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.appcompat:appcompat:1.1.0'
|
||||
implementation 'com.google.android.material:material:1.1.0'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
|
||||
testImplementation 'junit:junit:4.+'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||
|
||||
//litongjava
|
||||
implementation 'com.litongjava:android-view-inject:1.0'
|
||||
implementation 'com.litongjava:jfinal-aop:1.0.1'
|
||||
implementation 'com.litongjava:litongjava-android-utils:1.0.0'
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -1,26 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class ExampleInstrumentedTest {
|
||||
@Test
|
||||
public void useAppContext() {
|
||||
// Context of the app under test.
|
||||
Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
assertEquals("com.litongjava.whisper.android.java", appContext.getPackageName());
|
||||
}
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.litongjava.whisper.android.java">
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:name=".app.App"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.Whisperandroidjava">
|
||||
<activity android:name=".MainActivity">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -1,40 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<configuration debug="false" xmlns="http://ch.qos.logback/xml/ns/logback"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://ch.qos.logback/xml/ns/logback https://raw.githubusercontent.com/enricopulatzo/logback-XSD/master/src/main/xsd/logback.xsd
|
||||
http://ch.qos.logback/xml/ns/logback ">
|
||||
<!--Define the storage address of the log file Do not use relative paths in the LogBack configuration. -->
|
||||
<property name="LOG_HOME" value="logs" />
|
||||
<!--Formatted output: %d means the date, %-6level: log level from the left display 6 characters wide, %m: log message, %n is a newline character -->
|
||||
<property name="CONSOLE_LOG_PATTERN"
|
||||
value="%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-6level%logger{0}.%M:%L - %m%n" />
|
||||
|
||||
<!-- console output -->
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<!-- Generate log files on a daily basis -->
|
||||
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
|
||||
<!--File name for log file output -->
|
||||
<fileNamePattern>${LOG_HOME}/project-name-%d{yyyy-MM-dd}.log</fileNamePattern>
|
||||
<!--Maximum size of log file -->
|
||||
<maxHistory>180</maxHistory>
|
||||
</rollingPolicy>
|
||||
<!--日志文件最大的大小 -->
|
||||
<triggeringPolicy class="ch.qos.logback.core.rolling.SizeBasedTriggeringPolicy">
|
||||
<maxFileSize>10MB</maxFileSize>
|
||||
</triggeringPolicy>
|
||||
</appender>
|
||||
<!-- Log output level and source-->
|
||||
<root level="info">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
</configuration>
|
@ -1,107 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.view.View;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewById;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewByIdLayout;
|
||||
import com.litongjava.android.view.inject.annotation.OnClick;
|
||||
import com.litongjava.android.view.inject.utils.ViewInjectUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.jfinal.aop.AopManager;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
import com.litongjava.whisper.android.java.task.LoadModelTask;
|
||||
import com.litongjava.whisper.android.java.task.TranscriptionTask;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperLib;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
|
||||
@FindViewByIdLayout(R.layout.activity_main)
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
|
||||
@FindViewById(R.id.sample_text)
|
||||
private TextView tv;
|
||||
|
||||
Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
private WhisperService whisperService = Aop.get(WhisperService.class);
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
//setContentView(R.layout.activity_main);
|
||||
ViewInjectUtils.injectActivity(this, this);
|
||||
initAopBean();
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
private void initAopBean() {
|
||||
Handler mainHandler = new Handler(Looper.getMainLooper());
|
||||
AopManager.me().addSingletonObject(mainHandler);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.loadModelBtn)
|
||||
public void loadModelBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
ThreadUtils.executeByIo(new LoadModelTask(tv));
|
||||
}
|
||||
|
||||
@OnClick(R.id.transcriptSampleBtn)
|
||||
public void transcriptSampleBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
String sampleFilePath = "samples/jfk.wav";
|
||||
File filesDir = context.getFilesDir();
|
||||
File sampleFile = AssetUtils.copyFileIfNotExists(context, filesDir, sampleFilePath);
|
||||
long end = System.currentTimeMillis();
|
||||
String msg = "copy file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ThreadUtils.executeByIo(new TranscriptionTask(tv, sampleFile));
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
tv.append(msg + "\n");
|
||||
log.info(msg);
|
||||
}
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.systemInfoBtn)
|
||||
public void systemInfoBtn_OnClick(View v) {
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void showSystemInfo() {
|
||||
String systemInfo = WhisperLib.getSystemInfo();
|
||||
tv.append(systemInfo + "\n");
|
||||
}
|
||||
|
||||
@OnClick(R.id.clearBtn)
|
||||
public void clearBtn_OnClick(View v) {
|
||||
tv.setText("");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onDestroy() {
|
||||
super.onDestroy();
|
||||
whisperService.release();
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.app;
|
||||
|
||||
import android.app.Application;
|
||||
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
|
||||
public class App extends Application {
|
||||
@Override
|
||||
public void onCreate() {
|
||||
super.onCreate();
|
||||
Utils.init(this);
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "["+start+" --> "+end+"]:"+sentence;
|
||||
}
|
||||
}
|
@ -1,101 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.services;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.android.utils.dialog.AlertDialogUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.single.LocalWhisper;
|
||||
import com.litongjava.whisper.android.java.utils.WaveEncoder;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
public class WhisperService {
|
||||
private Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
|
||||
private final Object lock = new Object();
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void loadModel(TextView tv) {
|
||||
String modelFilePath = LocalWhisper.modelFilePath;
|
||||
String msg = "load model from :" + modelFilePath + "\n";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
LocalWhisper.INSTANCE.init();
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "model load successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ToastUtils.showLong(msg);
|
||||
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void transcribeSample(TextView tv, File sampleFile) {
|
||||
String msg = "";
|
||||
msg = "transcribe file from :" + sampleFile.getAbsolutePath();
|
||||
outputMsg(tv, msg);
|
||||
|
||||
Long start = System.currentTimeMillis();
|
||||
float[] audioData = new float[0]; // 读取音频样本
|
||||
try {
|
||||
audioData = WaveEncoder.decodeWaveFile(sampleFile);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return;
|
||||
}
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "decode wave file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
start = System.currentTimeMillis();
|
||||
List<WhisperSegment> transcription = null;
|
||||
try {
|
||||
//transcription = LocalWhisper.INSTANCE.transcribeData(audioData);
|
||||
transcription = LocalWhisper.INSTANCE.transcribeDataWithTime(audioData);
|
||||
} catch (ExecutionException e) {
|
||||
e.printStackTrace();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
end = System.currentTimeMillis();
|
||||
if(transcription!=null){
|
||||
ToastUtils.showLong(transcription.toString());
|
||||
msg = "Transcript successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
outputMsg(tv, transcription.toString());
|
||||
|
||||
}else{
|
||||
msg = "Transcript failed:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
log.info(msg);
|
||||
if(tv!=null){
|
||||
Aop.get(Handler.class).post(()->{ tv.append(msg + "\n");});
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() {
|
||||
//noting to do
|
||||
}
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.single;
|
||||
|
||||
import android.app.Application;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperContext;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public enum LocalWhisper {
|
||||
INSTANCE;
|
||||
|
||||
public static final String modelFilePath = "models/ggml-tiny.bin";
|
||||
private WhisperContext whisperContext;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
LocalWhisper() {
|
||||
Application context = Utils.getApp();
|
||||
File filesDir = context.getFilesDir();
|
||||
File modelFile = AssetUtils.copyFileIfNotExists(context, filesDir, modelFilePath);
|
||||
String realModelFilePath = modelFile.getAbsolutePath();
|
||||
whisperContext = WhisperContext.createContextFromFile(realModelFilePath);
|
||||
}
|
||||
|
||||
public synchronized String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeData(data);
|
||||
}
|
||||
}
|
||||
|
||||
private static void toastModelLoading() {
|
||||
Aop.get(Handler.class).post(()->{
|
||||
ToastUtils.showShort("please wait for model loading");
|
||||
});
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] audioData) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeDataWithTime(audioData);
|
||||
}
|
||||
}
|
||||
|
||||
public void init() {
|
||||
//noting to do.but init
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class LoadModelTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
public LoadModelTask(TextView tv) {
|
||||
this.tv = tv;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).loadModel(tv);
|
||||
}else{
|
||||
Aop.get(Handler.class).post(()->{
|
||||
tv.append("not supported android devices");
|
||||
});
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class TranscriptionTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
private final File sampleFile;
|
||||
|
||||
public TranscriptionTask(TextView tv, File sampleFile) {
|
||||
this.tv = tv;
|
||||
this.sampleFile = sampleFile;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).transcribeSample(tv, sampleFile);
|
||||
}else{
|
||||
tv.append("not supported android devices");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
public class AssetUtils {
|
||||
private static Logger log = LoggerFactory.getLogger(AssetUtils.class);
|
||||
|
||||
public static File copyFileIfNotExists(Context context, File distDir, String filename) {
|
||||
File dstFile = new File(distDir, filename);
|
||||
if (dstFile.exists()) {
|
||||
return dstFile;
|
||||
} else {
|
||||
File parentFile = dstFile.getParentFile();
|
||||
log.info("parentFile:{}", parentFile);
|
||||
if (!parentFile.exists()) {
|
||||
parentFile.mkdirs();
|
||||
}
|
||||
AssetUtils.copyFileFromAssets(context, filename, dstFile);
|
||||
}
|
||||
return dstFile;
|
||||
}
|
||||
|
||||
public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
|
||||
if (srcDir.isEmpty() || dstDir.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (!new File(dstDir).exists()) {
|
||||
new File(dstDir).mkdirs();
|
||||
}
|
||||
for (String fileName : appCtx.getAssets().list(srcDir)) {
|
||||
String srcSubPath = srcDir + File.separator + fileName;
|
||||
String dstSubPath = dstDir + File.separator + fileName;
|
||||
if (new File(srcSubPath).isDirectory()) {
|
||||
copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
} else {
|
||||
copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
|
||||
File dstFile = new File(dstPath);
|
||||
copyFileFromAssets(appCtx, srcPath, dstFile);
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, File dstFile) {
|
||||
if (srcPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
InputStream is = null;
|
||||
OutputStream os = null;
|
||||
try {
|
||||
is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
|
||||
|
||||
os = new BufferedOutputStream(new FileOutputStream(dstFile));
|
||||
byte[] buffer = new byte[1024];
|
||||
int length = 0;
|
||||
while ((length = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, length);
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
try {
|
||||
os.close();
|
||||
is.close();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.ShortBuffer;
|
||||
|
||||
public class WaveEncoder {
|
||||
|
||||
public static float[] decodeWaveFile(File file) throws IOException {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
try (FileInputStream fis = new FileInputStream(file)) {
|
||||
byte[] buffer = new byte[1024];
|
||||
int bytesRead;
|
||||
while ((bytesRead = fis.read(buffer)) != -1) {
|
||||
baos.write(buffer, 0, bytesRead);
|
||||
}
|
||||
}
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
|
||||
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
int channel = byteBuffer.getShort(22);
|
||||
byteBuffer.position(44);
|
||||
|
||||
ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
|
||||
short[] shortArray = new short[shortBuffer.limit()];
|
||||
shortBuffer.get(shortArray);
|
||||
|
||||
float[] output = new float[shortArray.length / channel];
|
||||
|
||||
for (int index = 0; index < output.length; index++) {
|
||||
if (channel == 1) {
|
||||
output[index] = Math.max(-1f, Math.min(1f, shortArray[index] / 32767.0f));
|
||||
} else {
|
||||
output[index] = Math.max(-1f, Math.min(1f, (shortArray[2 * index] + shortArray[2 * index + 1]) / 32767.0f / 2.0f));
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
public static void encodeWaveFile(File file, short[] data) throws IOException {
|
||||
try (FileOutputStream fos = new FileOutputStream(file)) {
|
||||
fos.write(headerBytes(data.length * 2));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(data.length * 2);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.asShortBuffer().put(data);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
fos.write(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] headerBytes(int totalLength) {
|
||||
if (totalLength < 44)
|
||||
throw new IllegalArgumentException("Total length must be at least 44 bytes");
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(44);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
buffer.put((byte) 'R');
|
||||
buffer.put((byte) 'I');
|
||||
buffer.put((byte) 'F');
|
||||
buffer.put((byte) 'F');
|
||||
|
||||
buffer.putInt(totalLength - 8);
|
||||
|
||||
buffer.put((byte) 'W');
|
||||
buffer.put((byte) 'A');
|
||||
buffer.put((byte) 'V');
|
||||
buffer.put((byte) 'E');
|
||||
|
||||
buffer.put((byte) 'f');
|
||||
buffer.put((byte) 'm');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) ' ');
|
||||
|
||||
buffer.putInt(16);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putInt(16000);
|
||||
buffer.putInt(32000);
|
||||
buffer.putShort((short) 2);
|
||||
buffer.putShort((short) 16);
|
||||
|
||||
buffer.put((byte) 'd');
|
||||
buffer.put((byte) 'a');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) 'a');
|
||||
|
||||
buffer.putInt(totalLength - 44);
|
||||
buffer.position(0);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class CpuInfo {
|
||||
private static final String LOG_TAG = "WhisperCpuConfig";
|
||||
|
||||
private List<String> lines;
|
||||
|
||||
public CpuInfo(List<String> lines) {
|
||||
this.lines = lines;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public int getHighPerfCpuCount0() {
|
||||
try {
|
||||
return getHighPerfCpuCountByFrequencies();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e);
|
||||
return getHighPerfCpuCountByVariant();
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByFrequencies() {
|
||||
List<Integer> frequencies = getCpuValues("processor", line -> {
|
||||
try {
|
||||
return getMaxCpuFrequency(Integer.parseInt(line.trim()));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
);
|
||||
Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): " + binnedValues(frequencies));
|
||||
return countDroppingMin(frequencies);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByVariant() {
|
||||
List<Integer> variants = getCpuValues("CPU variant", line -> Integer.parseInt(line.trim().substring(line.indexOf("0x") + 2), 16));
|
||||
Log.d(LOG_TAG, "Binned cpu variants (variant, count): " + binnedValues(variants));
|
||||
return countKeepingMin(variants);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private Map<Integer, Integer> binnedValues(List<Integer> values) {
|
||||
Map<Integer, Integer> countMap = new HashMap<>();
|
||||
for (int value : values) {
|
||||
countMap.put(value, countMap.getOrDefault(value, 0) + 1);
|
||||
}
|
||||
return countMap;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private List<Integer> getCpuValues(String property, Mapper mapper) {
|
||||
List<Integer> values = new ArrayList<>();
|
||||
for (String line : lines) {
|
||||
if (line.startsWith(property)) {
|
||||
values.add(mapper.map(line.substring(line.indexOf(':') + 1)));
|
||||
}
|
||||
}
|
||||
values.sort(Integer::compareTo);
|
||||
return values;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countDroppingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value > min).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countKeepingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value.equals(min)).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getHighPerfCpuCount() {
|
||||
try {
|
||||
return readCpuInfo().getHighPerfCpuCount0();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU info", e);
|
||||
return Math.max(Runtime.getRuntime().availableProcessors() - 4, 0);
|
||||
}
|
||||
}
|
||||
|
||||
private static CpuInfo readCpuInfo() throws IOException {
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader("/proc/cpuinfo"))) {
|
||||
List<String> lines = new ArrayList<>();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
lines.add(line);
|
||||
}
|
||||
return new CpuInfo(lines);
|
||||
}
|
||||
}
|
||||
|
||||
private static int getMaxCpuFrequency(int cpuIndex) throws IOException {
|
||||
String path = "/sys/devices/system/cpu/cpu" + cpuIndex + "/cpufreq/cpuinfo_max_freq";
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
return Integer.parseInt(reader.readLine());
|
||||
}
|
||||
}
|
||||
|
||||
private interface Mapper {
|
||||
int map(String line);
|
||||
}
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
public class WhisperContext {
|
||||
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
private long ptr;
|
||||
private final ExecutorService executorService;
|
||||
|
||||
private WhisperContext(long ptr) {
|
||||
this.ptr = ptr;
|
||||
this.executorService = Executors.newSingleThreadExecutor();
|
||||
}
|
||||
|
||||
public String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<String>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public String call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
StringBuilder result = new StringBuilder();
|
||||
synchronized (this) {
|
||||
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
result.append(sentence);
|
||||
}
|
||||
}
|
||||
return result.toString();
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<List<WhisperSegment>>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public List<WhisperSegment> call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
List<WhisperSegment> segments = new ArrayList<>();
|
||||
synchronized (this) {
|
||||
// StringBuilder result = new StringBuilder();
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
long start = WhisperLib.getTextSegmentT0(ptr, i);
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
long end = WhisperLib.getTextSegmentT1(ptr, i);
|
||||
// result.append();
|
||||
segments.add(new WhisperSegment(start, end, sentence));
|
||||
|
||||
}
|
||||
// return result.toString();
|
||||
}
|
||||
return segments;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchMemory(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchMemcpy(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchGgmlMulMat(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchGgmlMulMat(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() throws ExecutionException, InterruptedException {
|
||||
executorService.submit(() -> {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr);
|
||||
ptr = 0;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromFile(String filePath) {
|
||||
long ptr = WhisperLib.initContext(filePath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context with path " + filePath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromInputStream(InputStream stream) {
|
||||
long ptr = WhisperLib.initContextFromInputStream(stream);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from input stream");
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromAsset(AssetManager assetManager, String assetPath) {
|
||||
long ptr = WhisperLib.initContextFromAsset(assetManager, assetPath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from asset " + assetPath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String getSystemInfo() {
|
||||
return WhisperLib.getSystemInfo();
|
||||
}
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
public class WhisperCpuConfig {
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getPreferredThreadCount() {
|
||||
return Math.max(CpuInfo.getHighPerfCpuCount(), 2);
|
||||
}
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public class WhisperLib {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
static {
|
||||
|
||||
Log.d(LOG_TAG, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
|
||||
boolean loadVfpv4 = false;
|
||||
boolean loadV8fp16 = false;
|
||||
if (WhisperUtils.isArmEabiV7a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("vfpv4")) {
|
||||
Log.d(LOG_TAG, "CPU supports vfpv4");
|
||||
loadVfpv4 = true;
|
||||
}
|
||||
}
|
||||
} else if (WhisperUtils.isArmEabiV8a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("fphp")) {
|
||||
Log.d(LOG_TAG, "CPU supports fp16 arithmetic");
|
||||
loadV8fp16 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (loadVfpv4) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so");
|
||||
System.loadLibrary("whisper_vfpv4");
|
||||
} else if (loadV8fp16) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so");
|
||||
System.loadLibrary("whisper_v8fp16_va");
|
||||
} else {
|
||||
Log.d(LOG_TAG, "Loading libwhisper.so");
|
||||
System.loadLibrary("whisper");
|
||||
}
|
||||
}
|
||||
|
||||
public static native long initContextFromInputStream(InputStream inputStream);
|
||||
|
||||
public static native long initContextFromAsset(AssetManager assetManager, String assetPath);
|
||||
|
||||
public static native long initContext(String modelPath);
|
||||
|
||||
public static native void freeContext(long contextPtr);
|
||||
|
||||
public static native void fullTranscribe(long contextPtr, int numThreads, float[] audioData);
|
||||
|
||||
public static native int getTextSegmentCount(long contextPtr);
|
||||
|
||||
public static native String getTextSegment(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT0(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT1(long contextPtr, int index);
|
||||
|
||||
public static native String getSystemInfo();
|
||||
|
||||
public static native String benchMemcpy(int nthread);
|
||||
|
||||
public static native String benchGgmlMulMat(int nthread);
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class WhisperUtils {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
|
||||
public static boolean isArmEabiV7a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a");
|
||||
}
|
||||
|
||||
public static boolean isArmEabiV8a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String cpuInfo() {
|
||||
try {
|
||||
Path path = new File("/proc/cpuinfo").toPath();
|
||||
return new String(java.nio.file.Files.readAllBytes(path));
|
||||
} catch (Exception e) {
|
||||
Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,56 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.10)
|
||||
|
||||
project(whisper.cpp)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
|
||||
|
||||
set(
|
||||
SOURCE_FILES
|
||||
${WHISPER_LIB_DIR}/ggml.c
|
||||
${WHISPER_LIB_DIR}/ggml-alloc.c
|
||||
${WHISPER_LIB_DIR}/ggml-backend.c
|
||||
${WHISPER_LIB_DIR}/ggml-quants.c
|
||||
${WHISPER_LIB_DIR}/whisper.cpp
|
||||
${CMAKE_SOURCE_DIR}/jni.c
|
||||
)
|
||||
|
||||
find_library(LOG_LIB log)
|
||||
|
||||
function(build_library target_name)
|
||||
add_library(
|
||||
${target_name}
|
||||
SHARED
|
||||
${SOURCE_FILES}
|
||||
)
|
||||
|
||||
target_link_libraries(${target_name} ${LOG_LIB} android)
|
||||
|
||||
if (${target_name} STREQUAL "whisper_v8fp16_va")
|
||||
target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16)
|
||||
elseif (${target_name} STREQUAL "whisper_vfpv4")
|
||||
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
|
||||
endif ()
|
||||
|
||||
if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
|
||||
target_compile_options(${target_name} PRIVATE -O3)
|
||||
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
|
||||
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
|
||||
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
|
||||
#target_link_options(${target_name} PRIVATE -flto)
|
||||
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
build_library("whisper") # Default target
|
||||
|
||||
if (${ANDROID_ABI} STREQUAL "arm64-v8a")
|
||||
build_library("whisper_v8fp16_va")
|
||||
elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a")
|
||||
build_library("whisper_vfpv4")
|
||||
endif ()
|
||||
|
||||
include_directories(${WHISPER_LIB_DIR})
|
@ -1,257 +0,0 @@
|
||||
#include <jni.h>
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <string.h>
|
||||
#include "whisper.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||
|
||||
static inline int min(int a, int b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
static inline int max(int a, int b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
struct input_stream_context {
|
||||
size_t offset;
|
||||
JNIEnv * env;
|
||||
jobject thiz;
|
||||
jobject input_stream;
|
||||
|
||||
jmethodID mid_available;
|
||||
jmethodID mid_read;
|
||||
};
|
||||
|
||||
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
|
||||
|
||||
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
|
||||
|
||||
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
|
||||
|
||||
if (size_to_copy != read_size || size_to_copy != n_read) {
|
||||
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
|
||||
}
|
||||
|
||||
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
|
||||
memcpy(output, byte_array_elements, size_to_copy);
|
||||
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
|
||||
|
||||
(*is->env)->DeleteLocalRef(is->env, byte_array);
|
||||
|
||||
is->offset += size_to_copy;
|
||||
|
||||
return size_to_copy;
|
||||
}
|
||||
bool inputStreamEof(void * ctx) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
return result <= 0;
|
||||
}
|
||||
void inputStreamClose(void * ctx) {
|
||||
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromInputStream(
|
||||
JNIEnv *env, jobject thiz, jobject input_stream) {
|
||||
UNUSED(thiz);
|
||||
|
||||
struct whisper_context *context = NULL;
|
||||
struct whisper_model_loader loader = {};
|
||||
struct input_stream_context inp_ctx = {};
|
||||
|
||||
inp_ctx.offset = 0;
|
||||
inp_ctx.env = env;
|
||||
inp_ctx.thiz = thiz;
|
||||
inp_ctx.input_stream = input_stream;
|
||||
|
||||
jclass cls = (*env)->GetObjectClass(env, input_stream);
|
||||
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
|
||||
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
|
||||
|
||||
loader.context = &inp_ctx;
|
||||
loader.read = inputStreamRead;
|
||||
loader.eof = inputStreamEof;
|
||||
loader.close = inputStreamClose;
|
||||
|
||||
loader.eof(loader.context);
|
||||
|
||||
context = whisper_init(&loader);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
static size_t asset_read(void *ctx, void *output, size_t read_size) {
|
||||
return AAsset_read((AAsset *) ctx, output, read_size);
|
||||
}
|
||||
|
||||
static bool asset_is_eof(void *ctx) {
|
||||
return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
|
||||
}
|
||||
|
||||
static void asset_close(void *ctx) {
|
||||
AAsset_close((AAsset *) ctx);
|
||||
}
|
||||
|
||||
static struct whisper_context *whisper_init_from_asset(
|
||||
JNIEnv *env,
|
||||
jobject assetManager,
|
||||
const char *asset_path
|
||||
) {
|
||||
LOGI("Loading model from asset '%s'\n", asset_path);
|
||||
AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
|
||||
AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING);
|
||||
if (!asset) {
|
||||
LOGW("Failed to open '%s'\n", asset_path);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_model_loader loader = {
|
||||
.context = asset,
|
||||
.read = &asset_read,
|
||||
.eof = &asset_is_eof,
|
||||
.close = &asset_close
|
||||
};
|
||||
|
||||
return whisper_init(&loader);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromAsset(
|
||||
JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL);
|
||||
context = whisper_init_from_asset(env, assetManager, asset_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContext(
|
||||
JNIEnv *env, jobject thiz, jstring model_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
|
||||
context = whisper_init_from_file(model_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_freeContext(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
whisper_free(context);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = num_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
||||
whisper_reset_timings(context);
|
||||
|
||||
LOGI("About to run whisper_full");
|
||||
if (whisper_full(context, params, audio_data_arr, audio_data_length) != 0) {
|
||||
LOGI("Failed to run the model");
|
||||
} else {
|
||||
whisper_print_timings(context);
|
||||
}
|
||||
(*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentCount(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
return whisper_full_n_segments(context);
|
||||
}
|
||||
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegment(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT0(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t0 = whisper_full_get_segment_t0(context, index);
|
||||
return (jlong)t0;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT1(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t1 = whisper_full_get_segment_t1(context, index);
|
||||
return (jlong)t1;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getSystemInfo(
|
||||
JNIEnv *env, jobject thiz
|
||||
) {
|
||||
UNUSED(thiz);
|
||||
const char *sysinfo = whisper_print_system_info();
|
||||
jstring string = (*env)->NewStringUTF(env, sysinfo);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchMemcpy(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchGgmlMulMat(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
|
||||
}
|
||||
|
@ -1,30 +0,0 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
@ -1,170 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
@ -1,57 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/systemInfoBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="System Info" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/loadModelBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Load model" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/transcriptSampleBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Transcribe sample" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/clearBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Clear" />
|
||||
</LinearLayout>
|
||||
|
||||
<TextView
|
||||
android:id="@+id/sample_text"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Hello World!"
|
||||
app:layout_constraintBottom_toBottomOf="parent"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent"
|
||||
android:scrollbarAlwaysDrawHorizontalTrack="true"
|
||||
android:maxLines="999"/>
|
||||
|
||||
</LinearLayout>
|
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 5.2 KiB |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 3.3 KiB |
Before Width: | Height: | Size: 4.8 KiB |
Before Width: | Height: | Size: 7.3 KiB |
Before Width: | Height: | Size: 7.7 KiB |
Before Width: | Height: | Size: 12 KiB |
Before Width: | Height: | Size: 10 KiB |
Before Width: | Height: | Size: 16 KiB |
@ -1,16 +0,0 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_200</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/black</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_200</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
@ -1,3 +0,0 @@
|
||||
<resources>
|
||||
<string name="app_name">whisper.android.java</string>
|
||||
</resources>
|
@ -1,16 +0,0 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_500</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/white</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_700</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -1,17 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
public class ExampleUnitTest {
|
||||
@Test
|
||||
public void addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2);
|
||||
}
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
buildscript {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
dependencies {
|
||||
classpath "com.android.tools.build:gradle:4.1.3"
|
||||
|
||||
// NOTE: Do not place your application dependencies here; they belong
|
||||
// in the individual module build.gradle files
|
||||
}
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
task clean(type: Delete) {
|
||||
delete rootProject.buildDir
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app"s APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Automatically convert third-party libraries to use AndroidX
|
||||
android.enableJetifier=true
|
@ -1,6 +0,0 @@
|
||||
#Fri Oct 20 11:07:15 HST 2023
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-all.zip
|
172
examples/whisper.android.java/gradlew
vendored
@ -1,172 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS=""
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin, switch paths to Windows format before running java
|
||||
if $cygwin ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=$((i+1))
|
||||
done
|
||||
case $i in
|
||||
(0) set -- ;;
|
||||
(1) set -- "$args0" ;;
|
||||
(2) set -- "$args0" "$args1" ;;
|
||||
(3) set -- "$args0" "$args1" "$args2" ;;
|
||||
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=$(save "$@")
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
|
||||
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
|
||||
cd "$(dirname "$0")"
|
||||
fi
|
||||
|
||||
exec "$JAVACMD" "$@"
|
84
examples/whisper.android.java/gradlew.bat
vendored
@ -1,84 +0,0 @@
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS=
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:init
|
||||
@rem Get command-line arguments, handling Windows variants
|
||||
|
||||
if not "%OS%" == "Windows_NT" goto win9xME_args
|
||||
|
||||
:win9xME_args
|
||||
@rem Slurp the command line arguments.
|
||||
set CMD_LINE_ARGS=
|
||||
set _SKIP=2
|
||||
|
||||
:win9xME_args_slurp
|
||||
if "x%~1" == "x" goto execute
|
||||
|
||||
set CMD_LINE_ARGS=%*
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
@ -1,2 +0,0 @@
|
||||
include ':app'
|
||||
rootProject.name = "whisper.android.java"
|
2
examples/whisper.android/.idea/gradle.xml
generated
@ -4,7 +4,6 @@
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
|
||||
<option name="modules">
|
||||
@ -14,7 +13,6 @@
|
||||
</set>
|
||||
</option>
|
||||
<option name="resolveExternalAnnotations" value="false" />
|
||||
<option name="resolveModulePerSourceSet" value="false" />
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
|
@ -17,12 +17,12 @@ else
|
||||
encoder_only=$2
|
||||
fi
|
||||
|
||||
models=( \
|
||||
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" "medium-dis" \
|
||||
"large-v2" "large-v2-q4_0" "large-v2-q4_1" "large-v2-q5_0" "large-v2-q5_1" "large-v2-q8_0" "large-v2-dis" \
|
||||
models=( \
|
||||
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
|
||||
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
|
||||
)
|
||||
|
||||
if [ "$encoder_only" -eq 0 ]; then
|
||||
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# actual run
|
||||
@ -56,7 +56,6 @@ for model in "${models[@]}"; do
|
||||
# parse the output:
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
|
||||
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||
system_info=$(echo "$output" | grep "system_info")
|
||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||
@ -95,6 +94,6 @@ for model in "${models[@]}"; do
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
python3 models/convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/
|
||||
|
317
ggml-cuda.cu
@ -39,6 +39,7 @@
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetMemPool hipDeviceGetMemPool
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
@ -48,6 +49,7 @@
|
||||
#define cudaEvent_t hipEvent_t
|
||||
#define cudaEventDestroy hipEventDestroy
|
||||
#define cudaFree hipFree
|
||||
#define cudaFreeAsync hipFreeAsync
|
||||
#define cudaFreeHost hipHostFree
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaGetDeviceCount hipGetDeviceCount
|
||||
@ -55,6 +57,7 @@
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetLastError hipGetLastError
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||
@ -63,6 +66,9 @@
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemPool_t hipMemPool_t
|
||||
#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
|
||||
#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
@ -88,8 +94,6 @@
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||
|
||||
#define GGML_CUDA_MAX_NODES 8192
|
||||
|
||||
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
||||
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
||||
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
||||
@ -184,11 +188,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
int dev_id; \
|
||||
cudaGetDevice(&dev_id); \
|
||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
@ -198,11 +202,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
int dev_id; \
|
||||
cudaGetDevice(&dev_id); \
|
||||
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
|
||||
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
@ -436,8 +440,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
#define CUDA_GELU_BLOCK_SIZE 256
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_RELU_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||
#define CUDA_CLAMP_BLOCK_SIZE 256
|
||||
@ -470,6 +472,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
|
||||
|
||||
#define MAX_STREAMS 8
|
||||
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
|
||||
static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
|
||||
|
||||
struct ggml_tensor_extra_gpu {
|
||||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
||||
@ -558,24 +561,6 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = fmaxf(x[i], 0);
|
||||
}
|
||||
|
||||
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -1005,7 +990,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -1109,7 +1094,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -1213,7 +1198,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
const int ib0 = row*num_blocks_per_row;
|
||||
@ -1467,7 +1452,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -4277,7 +4262,7 @@ template <bool need_check> static __global__ void
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
@ -4317,7 +4302,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
// qk = quantized weights per x block
|
||||
// qr = number of quantized weights per data value in x block
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
@ -4756,7 +4741,7 @@ static __global__ void im2col_f32_f16(
|
||||
int ofs0, int ofs1, int IW, int IH, int CHW,
|
||||
int s0, int s1, int p0, int p1, int d0, int d1) {
|
||||
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
|
||||
const int offset_dst =
|
||||
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||
@ -4808,16 +4793,6 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
||||
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
@ -4926,8 +4901,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4936,7 +4910,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4945,7 +4919,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4954,7 +4928,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4963,7 +4937,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4973,7 +4947,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4982,7 +4956,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4991,7 +4965,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -5006,7 +4980,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -5014,7 +4988,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5023,7 +4997,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK4_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5032,7 +5006,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK5_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5041,7 +5015,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK5_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5050,7 +5024,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK8_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5059,7 +5033,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5068,7 +5042,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5077,7 +5051,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5086,7 +5060,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5095,7 +5069,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5114,7 +5088,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
|
||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -5851,6 +5825,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
|
||||
if (g_cudaMemPools[id] == nullptr) {
|
||||
return ggml_cuda_pool_malloc(size, actual_size);
|
||||
}
|
||||
void *ptr;
|
||||
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
|
||||
*actual_size = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
int id;
|
||||
@ -5868,10 +5852,12 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
static bool g_cublas_loaded = false;
|
||||
|
||||
bool ggml_cublas_loaded(void) {
|
||||
return g_cublas_loaded;
|
||||
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
|
||||
if (g_cudaMemPools[id] == nullptr) {
|
||||
return ggml_cuda_pool_free(ptr, actual_size);
|
||||
}
|
||||
CUDA_CHECK(cudaFreeAsync(ptr, stream));
|
||||
}
|
||||
|
||||
void ggml_init_cublas() {
|
||||
@ -5886,12 +5872,7 @@ void ggml_init_cublas() {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
#endif
|
||||
|
||||
if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
|
||||
initialized = true;
|
||||
g_cublas_loaded = false;
|
||||
return;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
|
||||
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
int64_t total_vram = 0;
|
||||
#if defined(GGML_CUDA_FORCE_MMQ)
|
||||
@ -5933,13 +5914,19 @@ void ggml_init_cublas() {
|
||||
// create cublas handle
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
||||
|
||||
// configure memory pool
|
||||
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
||||
if (err == cudaSuccess) {
|
||||
size_t treshold = UINT64_MAX;
|
||||
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
||||
}
|
||||
}
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
|
||||
|
||||
initialized = true;
|
||||
g_cublas_loaded = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -6206,34 +6193,6 @@ inline void ggml_cuda_op_silu(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_relu(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_sqr(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_norm(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6555,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
size_t ne = row_diff*ne00;
|
||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
|
||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
|
||||
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
||||
}
|
||||
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
||||
@ -6566,12 +6525,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||
size_t dst_f16_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
@ -6589,14 +6548,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
if (dst_f16_as != 0) {
|
||||
ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
|
||||
}
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
|
||||
}
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
|
||||
}
|
||||
}
|
||||
else {
|
||||
@ -6606,7 +6566,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
|
||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||
}
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||
@ -6623,7 +6583,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
&beta, dst_dd_i, ldc));
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
|
||||
ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -7048,8 +7008,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
int64_t row_low[GGML_CUDA_MAX_DEVICES];
|
||||
int64_t row_high[GGML_CUDA_MAX_DEVICES];
|
||||
|
||||
int used_devices = 0;
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
// by default, use all rows
|
||||
row_low[id] = 0;
|
||||
@ -7077,8 +7035,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
continue;
|
||||
}
|
||||
|
||||
used_devices++;
|
||||
|
||||
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
|
||||
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
|
||||
|
||||
@ -7089,21 +7045,22 @@ static void ggml_cuda_op_mul_mat(
|
||||
src0_dd[id] = (char *) src0_extra->data_device[id];
|
||||
} else {
|
||||
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
|
||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
|
||||
}
|
||||
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
||||
} else {
|
||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
|
||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
|
||||
}
|
||||
|
||||
if (convert_src1_to_q8_1) {
|
||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
|
||||
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
|
||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
|
||||
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
// CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
@ -7111,18 +7068,18 @@ static void ggml_cuda_op_mul_mat(
|
||||
dst_dd[id] = (float *) dst_extra->data_device[id];
|
||||
} else {
|
||||
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
|
||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
|
||||
}
|
||||
}
|
||||
|
||||
// if multiple devices are used they need to wait for the main device
|
||||
// here an event is recorded that signals that the main device has finished calculating the input data
|
||||
if (split && used_devices > 1) {
|
||||
if (split && g_device_count > 1) {
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
|
||||
}
|
||||
|
||||
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
|
||||
const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
|
||||
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
|
||||
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
|
||||
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
|
||||
@ -7237,27 +7194,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
|
||||
continue;
|
||||
}
|
||||
CUDA_CHECK(ggml_cuda_set_device(id));
|
||||
|
||||
// free buffers again when done
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
|
||||
}
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
|
||||
}
|
||||
if (src1_asq[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
|
||||
}
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
|
||||
}
|
||||
}
|
||||
|
||||
// main device waits for all other devices to be finished
|
||||
if (split && g_device_count > 1) {
|
||||
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
|
||||
@ -7265,9 +7201,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (row_low[id] == row_high[id]) {
|
||||
continue;
|
||||
}
|
||||
for (int64_t is = 0; is < is_max; ++is) {
|
||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
|
||||
}
|
||||
@ -7278,6 +7211,21 @@ static void ggml_cuda_op_mul_mat(
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (src1_asq[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -7304,14 +7252,6 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
||||
}
|
||||
|
||||
static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
||||
}
|
||||
@ -7321,8 +7261,6 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
|
||||
}
|
||||
|
||||
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||
if (!g_cublas_loaded) return false;
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
@ -7474,11 +7412,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
|
||||
size_t src1_as = 0;
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
@ -7532,8 +7470,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
size_t ptrs_src_s = 0;
|
||||
size_t ptrs_dst_s = 0;
|
||||
|
||||
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
|
||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
|
||||
ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
|
||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
@ -7546,7 +7484,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
dst->nb[2], dst->nb[3],
|
||||
r2, r3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
@ -7558,29 +7495,30 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
if (ptrs_src_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||
ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
|
||||
}
|
||||
if (ptrs_dst_s != 0) {
|
||||
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
||||
ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
|
||||
}
|
||||
if (dst_as != 0) {
|
||||
ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool all_on_device =
|
||||
(src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
|
||||
(src0->backend == GGML_BACKEND_GPU) &&
|
||||
(src1->backend == GGML_BACKEND_GPU) &&
|
||||
( dst->backend == GGML_BACKEND_GPU);
|
||||
|
||||
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
|
||||
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
||||
@ -7602,13 +7540,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
} else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
} else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
@ -7729,7 +7667,7 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
||||
}
|
||||
|
||||
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
|
||||
}
|
||||
|
||||
@ -7844,11 +7782,11 @@ static size_t g_temp_tensor_extra_index = 0;
|
||||
|
||||
static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (g_temp_tensor_extras == nullptr) {
|
||||
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
|
||||
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
|
||||
}
|
||||
|
||||
size_t alloc_index = g_temp_tensor_extra_index;
|
||||
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
|
||||
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
|
||||
ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
@ -8015,8 +7953,6 @@ void ggml_cuda_free_scratch() {
|
||||
}
|
||||
|
||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
if (!g_cublas_loaded) return false;
|
||||
|
||||
ggml_cuda_func_t func;
|
||||
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|
||||
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
@ -8026,15 +7962,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT) {
|
||||
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
|
||||
#ifndef NDEBUG
|
||||
fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_REPEAT:
|
||||
func = ggml_cuda_repeat;
|
||||
@ -8059,9 +7986,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_UNARY_OP_SILU:
|
||||
func = ggml_cuda_silu;
|
||||
break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
func = ggml_cuda_relu;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
} break;
|
||||
@ -8080,9 +8004,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_SCALE:
|
||||
func = ggml_cuda_scale;
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
func = ggml_cuda_sqr;
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
@ -8175,11 +8096,11 @@ struct ggml_backend_buffer_context_cuda {
|
||||
|
||||
ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (temp_tensor_extras == nullptr) {
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
|
||||
}
|
||||
|
||||
size_t alloc_index = temp_tensor_extra_index;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
|
||||
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
|
@ -17,12 +17,7 @@ extern "C" {
|
||||
|
||||
#define GGML_CUDA_MAX_DEVICES 16
|
||||
|
||||
// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
|
||||
GGML_API void ggml_init_cublas(void);
|
||||
|
||||
// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
|
||||
GGML_API bool ggml_cublas_loaded(void);
|
||||
|
||||
GGML_API void * ggml_cuda_host_malloc(size_t size);
|
||||
GGML_API void ggml_cuda_host_free(void * ptr);
|
||||
|
||||
|
27
ggml-metal.m
@ -128,7 +128,7 @@ struct ggml_metal_context {
|
||||
// MSL code
|
||||
// TODO: move the contents here when ready
|
||||
// for now it is easier to work in a separate file
|
||||
//static NSString * const msl_library_source = @"see metal.metal";
|
||||
static NSString * const msl_library_source = @"see metal.metal";
|
||||
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
@ -144,8 +144,7 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
|
||||
ggml_metal_log_user_data = user_data;
|
||||
}
|
||||
|
||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
||||
if (ggml_metal_log_callback != NULL) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
@ -340,15 +339,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
||||
if ([ctx->device supportsFamily:i]) {
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
if (ctx->device.maxTransferRate != 0) {
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
} else {
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
}
|
||||
@ -541,11 +540,11 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
|
||||
++ctx->n_buffers;
|
||||
} else {
|
||||
@ -565,11 +564,11 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
||||
if (i + size_step < size) {
|
||||
GGML_METAL_LOG_INFO("\n");
|
||||
}
|
||||
@ -580,16 +579,16 @@ bool ggml_metal_add_buffer(
|
||||
|
||||
#if TARGET_OS_OSX
|
||||
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
||||
ctx->device.currentAllocatedSize / 1e6,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
} else {
|
||||
GGML_METAL_LOG_INFO("\n");
|
||||
}
|
||||
#else
|
||||
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
|
||||
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -1,57 +0,0 @@
|
||||
# - "turn on lights."
|
||||
# - "set thermostat to 22."
|
||||
# - "increase TV by 10."
|
||||
# - "decrease oven by 50."
|
||||
# - "play music."
|
||||
# - "stop podcast."
|
||||
# - "schedule cleaning at 3pm."
|
||||
# - "cancel cleaning."
|
||||
# - "remind me to buy milk at 5pm."
|
||||
# - "show me security system."
|
||||
# - "hide washing machine."
|
||||
# - "what is the lights status?"
|
||||
# - "what is the current thermostat value?"
|
||||
# - "what is the security system status?"
|
||||
# - "what is the door lock status?"
|
||||
# - "what is the camera battery level?"
|
||||
# - "what is the weather like today?"
|
||||
# - "what is the forecast for tomorrow?"
|
||||
# - "what is the time?"
|
||||
# - "what is my schedule for today?"
|
||||
# - "what tasks do I have?"
|
||||
# - "what reminders do I have?"
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
|
||||
#
|
||||
|
||||
root ::= init " " (command | question) "."
|
||||
prompt ::= init
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " Ok Whisper, start listening for commands."
|
||||
|
||||
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
||||
"Increase " device " by " value | "Decrease " device " by " value |
|
||||
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
|
||||
"Remind me to " task " at " time | "Show me " device | "Hide " device
|
||||
|
||||
question ::= "What is the " device " status?" | "What is the current " device " value?" |
|
||||
"What is the " device " temperature?" | "What is the " device " humidity?" |
|
||||
"What is the " device " power consumption?" | "What is the " device " battery level?" |
|
||||
"What is the weather like today?" | "What is the forecast for tomorrow?" |
|
||||
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
|
||||
"What reminders do I have?"
|
||||
|
||||
device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
|
||||
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
|
||||
"vacuum cleaner"
|
||||
|
||||
value ::= [0-9]+
|
||||
|
||||
media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
|
||||
|
||||
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
||||
|
||||
time ::= [0-9] [0-9]? ("am" | "pm")?
|
@ -1,29 +0,0 @@
|
||||
# - bishop to c3
|
||||
# - rook to d4
|
||||
# - knight to e5
|
||||
# - d4 d5 knight to c3
|
||||
# - c3 queen to d4 king b1
|
||||
# - pawn to a1 bishop to b2 knight to c3
|
||||
#
|
||||
# The prompt (--prompt) is the initial phrase that the user has to say.
|
||||
# This is used to prime Whisper with how the user is expected to speak.
|
||||
#
|
||||
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
|
||||
# Longer context is better, but it slightly increases the processing time.
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
|
||||
#
|
||||
|
||||
root ::= init move move? move? "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " rook to b4, f3"
|
||||
|
||||
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
||||
|
||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
||||
king ::= "king"
|
||||
pawn ::= "pawn"
|
@ -1,16 +0,0 @@
|
||||
# - red
|
||||
# - green
|
||||
# - blue
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
|
||||
#
|
||||
|
||||
root ::= init color "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " red, green, blue"
|
||||
|
||||
color ::= ", " ("red" | "green" | "blue")
|
@ -39,19 +39,19 @@ https://huggingface.co/ggerganov/whisper.cpp/tree/main
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Disk | SHA |
|
||||
| --- | --- | --- |
|
||||
| tiny | 75 MiB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| tiny.en | 75 MiB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
|
||||
| base | 142 MiB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| base.en | 142 MiB | `137c40403d78fd54d454da0f9bd998f78703390c` |
|
||||
| small | 466 MiB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| small.en | 466 MiB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
|
||||
| medium | 1.5 GiB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| medium.en | 1.5 GiB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
|
||||
| large-v1 | 2.9 GiB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
|
||||
| large-v2 | 2.9 GiB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
| large-v3 | 2.9 GiB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| tiny.en | 75 MB | ~390 MB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
|
||||
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| base.en | 142 MB | ~500 MB | `137c40403d78fd54d454da0f9bd998f78703390c` |
|
||||
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| small.en | 466 MB | ~1.0 GB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
|
||||
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
|
||||
| large-v1 | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
|
||||
| large-v2 | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
|
||||
| large | 2.9 GB | ~4.7 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
|
||||
|
||||
## Model files for testing purposes
|
||||
|
||||
@ -76,27 +76,3 @@ git clone https://huggingface.co/openai/whisper-medium
|
||||
# convert the model to ggml
|
||||
python3 ./whisper.cpp/models/convert-h5-to-ggml.py ./whisper-medium/ ./whisper .
|
||||
```
|
||||
|
||||
## Distilled models
|
||||
|
||||
Initial support for https://huggingface.co/distil-whisper is available.
|
||||
|
||||
Currently, the chunk-based transcription strategy is not implemented, so there can be sub-optimal quality when using the distilled models with `whisper.cpp`.
|
||||
|
||||
```bash
|
||||
# clone OpenAI whisper and whisper.cpp
|
||||
git clone https://github.com/openai/whisper
|
||||
git clone https://github.com/ggerganov/whisper.cpp
|
||||
|
||||
# get the models
|
||||
cd whisper.cpp/models
|
||||
git clone https://huggingface.co/distil-whisper/distil-medium.en
|
||||
git clone https://huggingface.co/distil-whisper/distil-large-v2
|
||||
|
||||
# convert to ggml
|
||||
python3 ./convert-h5-to-ggml.py ./distil-medium.en/ ../../whisper .
|
||||
mv ggml-model.bin ggml-medium.en-distil.bin
|
||||
|
||||
python3 ./convert-h5-to-ggml.py ./distil-large-v2/ ../../whisper .
|
||||
mv ggml-model.bin ggml-large-v2-distil.bin
|
||||
```
|
||||
|
@ -78,14 +78,14 @@ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
|
||||
# Ported from models/convert-whisper-to-coreml.py
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
|
||||
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
|
||||
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
|
||||
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
||||
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
||||
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
|
||||
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
pt_target_path = f"models/hf-{args.model_name}.pt"
|
||||
|
@ -296,13 +296,13 @@ def convert_decoder(hparams, model, quantize=False):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
|
||||
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
|
||||
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
|
||||
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
|
||||
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
whisper = load_model(args.model).cpu()
|
||||
|
@ -38,10 +38,10 @@ def convert_encoder(hparams, encoder, mname):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
|
||||
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
whisper = load_model(args.model).cpu()
|
||||
|
@ -19,7 +19,7 @@ function get_script_path() {
|
||||
models_path="$(get_script_path)"
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
@ -8,7 +8,7 @@ popd
|
||||
set argc=0
|
||||
for %%x in (%*) do set /A argc+=1
|
||||
|
||||
set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3
|
||||
set models=tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large
|
||||
|
||||
if %argc% neq 1 (
|
||||
echo.
|
||||
|
@ -22,7 +22,7 @@ function get_script_path() {
|
||||
models_path="$(get_script_path)"
|
||||
|
||||
# Whisper models
|
||||
models=(
|
||||
models=(
|
||||
"tiny.en"
|
||||
"tiny"
|
||||
"tiny-q5_1"
|
||||
@ -42,7 +42,7 @@ models=(
|
||||
"medium.en-q5_0"
|
||||
"large-v1"
|
||||
"large-v2"
|
||||
"large-v3"
|
||||
"large"
|
||||
"large-q5_0"
|
||||
)
|
||||
|
||||
|
@ -19,7 +19,7 @@
|
||||
cd `dirname $0`
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
1698
whisper.cpp
41
whisper.h
@ -78,9 +78,7 @@ extern "C" {
|
||||
struct whisper_state;
|
||||
struct whisper_full_params;
|
||||
|
||||
typedef int32_t whisper_pos;
|
||||
typedef int32_t whisper_token;
|
||||
typedef int32_t whisper_seq_id;
|
||||
typedef int whisper_token;
|
||||
|
||||
struct whisper_context_params {
|
||||
bool use_gpu;
|
||||
@ -111,37 +109,6 @@ extern "C" {
|
||||
void (*close)(void * ctx);
|
||||
} whisper_model_loader;
|
||||
|
||||
// grammar element type
|
||||
enum whisper_gretype {
|
||||
// end of rule definition
|
||||
WHISPER_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
WHISPER_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
WHISPER_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
WHISPER_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
WHISPER_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or
|
||||
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
WHISPER_GRETYPE_CHAR_ALT = 6,
|
||||
};
|
||||
|
||||
typedef struct whisper_grammar_element {
|
||||
enum whisper_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} whisper_grammar_element;
|
||||
|
||||
// Various functions for loading a ggml whisper model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
@ -435,7 +402,6 @@ extern "C" {
|
||||
|
||||
bool translate;
|
||||
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||
bool no_timestamps; // do not generate timestamps
|
||||
bool single_segment; // force single segment output (useful for streaming)
|
||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||
bool print_progress; // print progress information
|
||||
@ -513,11 +479,6 @@ extern "C" {
|
||||
// called by each decoder to filter obtained logits
|
||||
whisper_logits_filter_callback logits_filter_callback;
|
||||
void * logits_filter_callback_user_data;
|
||||
|
||||
const whisper_grammar_element ** grammar_rules;
|
||||
size_t n_grammar_rules;
|
||||
size_t i_start_rule;
|
||||
float grammar_penalty;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
|
||||
|