Compare commits
28 Commits
Author | SHA1 | Date | |
---|---|---|---|
270b1e48db | |||
b7c82a37b1 | |||
4c245ea108 | |||
820f45895e | |||
6c8a003acf | |||
ae1bd69041 | |||
3ed9af34f2 | |||
9006946e4b | |||
d77603578b | |||
b2123cb463 | |||
3d24e35f49 | |||
91096daa1a | |||
8b943f9843 | |||
3cbaaed060 | |||
d4231649e6 | |||
3e5c7feeff | |||
c23598e4ca | |||
54a08bde29 | |||
9f8bbd3fee | |||
3172006a24 | |||
684bc8bd70 | |||
b0502836b8 | |||
ec7a6f04f9 | |||
37947203e6 | |||
953419c69a | |||
0de8582f65 | |||
baeb733691 | |||
d03c60dd7f |
26
.github/workflows/build.yml
vendored
@ -396,6 +396,32 @@ jobs:
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
||||
android_java:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set up JDK 11
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: gradle
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
with:
|
||||
api-level: 30
|
||||
build-tools-version: 30.0.3
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android.java
|
||||
chmod +x ./gradlew
|
||||
./gradlew assembleRelease
|
||||
|
||||
java:
|
||||
needs: [ 'windows' ]
|
||||
runs-on: windows-latest
|
||||
|
10
.gitignore
vendored
@ -8,6 +8,7 @@
|
||||
.DS_Store
|
||||
|
||||
build/
|
||||
build-coreml/
|
||||
build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
@ -18,6 +19,11 @@ build-no-accel/
|
||||
build-sanitize-addr/
|
||||
build-sanitize-thread/
|
||||
|
||||
# SPM
|
||||
.build/
|
||||
.swiftpm
|
||||
*.metallib
|
||||
|
||||
/main
|
||||
/stream
|
||||
/command
|
||||
@ -48,3 +54,7 @@ bindings/java/.idea/
|
||||
.idea/
|
||||
|
||||
benchmark_results.csv
|
||||
cmake-build-debug/
|
||||
.cxx/
|
||||
.gradle/
|
||||
local.properties
|
42
Makefile
@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
|
||||
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
|
||||
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
@ -331,11 +331,11 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
|
||||
WHISPER_OBJ += ggml-metal.o
|
||||
endif
|
||||
|
||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||
libwhisper.a: $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a $(WHISPER_OBJ)
|
||||
|
||||
libwhisper.so: ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
|
||||
libwhisper.so: $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
|
||||
@ -349,30 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
|
||||
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
|
||||
SRC_COMMON_SDL = examples/common-sdl.cpp
|
||||
|
||||
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
./main -h
|
||||
|
||||
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
bench: examples/bench/bench.cpp $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
|
||||
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
#
|
||||
# Audio samples
|
||||
|
77
Package.swift
Normal file
@ -0,0 +1,77 @@
|
||||
// swift-tools-version:5.5
|
||||
|
||||
import PackageDescription
|
||||
|
||||
#if arch(arm) || arch(arm64)
|
||||
let platforms: [SupportedPlatform]? = [
|
||||
.macOS(.v12),
|
||||
.iOS(.v14),
|
||||
.watchOS(.v4),
|
||||
.tvOS(.v14)
|
||||
]
|
||||
let exclude: [String] = []
|
||||
let resources: [Resource] = [
|
||||
.process("ggml-metal.metal")
|
||||
]
|
||||
let additionalSources: [String] = ["ggml-metal.m"]
|
||||
let additionalSettings: [CSetting] = [
|
||||
.unsafeFlags(["-fno-objc-arc"]),
|
||||
.define("GGML_USE_METAL")
|
||||
]
|
||||
#else
|
||||
let platforms: [SupportedPlatform]? = nil
|
||||
let exclude: [String] = ["ggml-metal.metal"]
|
||||
let resources: [Resource] = []
|
||||
let additionalSources: [String] = []
|
||||
let additionalSettings: [CSetting] = []
|
||||
#endif
|
||||
|
||||
let package = Package(
|
||||
name: "whisper",
|
||||
platforms: platforms,
|
||||
products: [
|
||||
.library(name: "whisper", targets: ["whisper"]),
|
||||
],
|
||||
targets: [
|
||||
.target(
|
||||
name: "whisper",
|
||||
path: ".",
|
||||
exclude: exclude + [
|
||||
"bindings",
|
||||
"cmake",
|
||||
"coreml",
|
||||
"examples",
|
||||
"extra",
|
||||
"models",
|
||||
"samples",
|
||||
"tests",
|
||||
"CMakeLists.txt",
|
||||
"ggml-cuda.cu",
|
||||
"ggml-cuda.h",
|
||||
"Makefile"
|
||||
],
|
||||
sources: [
|
||||
"ggml.c",
|
||||
"whisper.cpp",
|
||||
"ggml-alloc.c",
|
||||
"ggml-backend.c",
|
||||
"ggml-quants.c"
|
||||
] + additionalSources,
|
||||
resources: resources,
|
||||
publicHeadersPath: "spm-headers",
|
||||
cSettings: [
|
||||
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
|
||||
.define("GGML_USE_ACCELERATE")
|
||||
// NOTE: NEW_LAPACK will required iOS version 16.4+
|
||||
// We should consider add this in the future when we drop support for iOS 14
|
||||
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
|
||||
// .define("ACCELERATE_NEW_LAPACK"),
|
||||
// .define("ACCELERATE_LAPACK_ILP64")
|
||||
] + additionalSettings,
|
||||
linkerSettings: [
|
||||
.linkedFramework("Accelerate")
|
||||
]
|
||||
)
|
||||
],
|
||||
cxxLanguageStandard: .cxx11
|
||||
)
|
16
README.md
@ -16,12 +16,10 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Support for CPU-only inference
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
@ -36,10 +34,8 @@ Supported platforms:
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
|
||||
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
|
||||
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
||||
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
|
||||
|
||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||
@ -400,12 +396,12 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
|
||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||
cached for the next run.
|
||||
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
## NVIDIA GPU support
|
||||
|
||||
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
|
||||
With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
|
@ -9,6 +9,7 @@ archivesBaseName = 'whispercpp'
|
||||
group = 'io.github.ggerganov'
|
||||
version = '1.4.0'
|
||||
|
||||
|
||||
sourceCompatibility = 1.8
|
||||
targetCompatibility = 1.8
|
||||
|
||||
|
@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -9,6 +10,8 @@ import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
|
||||
@ -160,6 +163,28 @@ public class WhisperCpp implements AutoCloseable {
|
||||
|
||||
return str.toString().trim();
|
||||
}
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
List<WhisperSegment> segments= new ArrayList<>(nSegments);
|
||||
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
long t1 = lib.whisper_full_get_segment_t1(ctx, i);
|
||||
segments.add(new WhisperSegment(t0,t1,text));
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
// public int getTextSegmentCount(Pointer ctx) {
|
||||
// return lib.whisper_full_n_segments(ctx);
|
||||
|
@ -0,0 +1,47 @@
|
||||
package io.github.ggerganov.whispercpp.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "[" + start + " --> " + end + "]:" + sentence;
|
||||
}
|
||||
}
|
@ -58,6 +58,9 @@ public class WhisperFullParams extends Structure {
|
||||
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Generate timestamps or not? */
|
||||
public CBool no_timestamps;
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public CBool single_segment;
|
||||
|
||||
@ -304,10 +307,16 @@ public class WhisperFullParams extends Structure {
|
||||
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
/** Grammar stuff */
|
||||
public Pointer grammar_rules;
|
||||
public long n_grammar_rules;
|
||||
public long i_start_rule;
|
||||
public float grammar_penalty;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment",
|
||||
"no_context", "single_segment", "no_timestamps",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
@ -316,6 +325,7 @@ public class WhisperFullParams extends Structure {
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||
"logits_filter_callback", "logits_filter_callback_user_data",
|
||||
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.CBool;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -11,6 +12,7 @@ import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.util.List;
|
||||
|
||||
class WhisperCppTest {
|
||||
private static WhisperCpp whisper = new WhisperCpp();
|
||||
@ -20,11 +22,12 @@ class WhisperCppTest {
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
//String modelName = "../../models/ggml-tiny.bin";
|
||||
String modelName = "../../models/ggml-tiny.en.bin";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
@ -72,11 +75,11 @@ class WhisperCppTest {
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
|
||||
try {
|
||||
@ -99,4 +102,43 @@ class WhisperCppTest {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFullTranscribeWithTime() throws Exception {
|
||||
if (!modelInitialised) {
|
||||
System.out.println("Model not initialised, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
// Given
|
||||
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
|
||||
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
|
||||
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
|
||||
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
|
||||
floats[j] = intSample / 32767.0f;
|
||||
}
|
||||
|
||||
List<WhisperSegment> segments = whisper.fullTranscribeWithTime(params, floats);
|
||||
assertTrue(segments.size() > 0, "The size of segments should be greater than 0");
|
||||
for (WhisperSegment segment : segments) {
|
||||
System.out.println(segment);
|
||||
}
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
|
||||
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
|
@ -3,6 +3,8 @@
|
||||
// Code is derived from the work of Github user @wangchou
|
||||
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
int64_t n_ctx,
|
||||
int64_t n_mel,
|
||||
float * mel,
|
||||
float * out);
|
||||
|
||||
|
@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
int64_t n_ctx,
|
||||
int64_t n_mel,
|
||||
float * mel,
|
||||
float * out) {
|
||||
MLMultiArray * inMultiArray = [
|
||||
[MLMultiArray alloc] initWithDataPointer: mel
|
||||
shape: @[@1, @80, @3000]
|
||||
shape: @[@1, @(n_mel), @(n_ctx)]
|
||||
dataType: MLMultiArrayDataTypeFloat32
|
||||
strides: @[@(240000), @(3000), @1]
|
||||
strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
|
||||
deallocator: nil
|
||||
error: nil
|
||||
];
|
||||
|
@ -23,6 +23,7 @@ add_library(${TARGET} STATIC
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
grammar-parser.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@ -21,6 +22,11 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -30,8 +36,12 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@ -45,6 +55,8 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -75,6 +87,9 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -109,16 +124,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
t_ms = 0;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
@ -126,19 +155,41 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
||||
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
|
||||
} else {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
}
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
@ -147,19 +198,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
|
||||
result += text;
|
||||
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
if(token.plog > 0.0f) exit(0);
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
@ -250,7 +299,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -418,7 +467,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
@ -456,7 +507,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -492,18 +543,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -536,9 +596,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
const float p = 100.0f * std::exp(logprob_min0);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
@ -559,19 +621,30 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= txt.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
@ -584,9 +657,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
}
|
||||
}
|
||||
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
@ -654,12 +734,36 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
if (!params.grammar.empty()) {
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
@ -181,7 +181,7 @@ private:
|
||||
// It is assumed that PCM data is normalized to a range from -1 to 1
|
||||
bool write_audio(const float * data, size_t length) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
|
||||
const int16_t intSample = data[i] * 32767;
|
||||
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
|
||||
dataSize += sizeof(int16_t);
|
||||
}
|
||||
|
423
examples/grammar-parser.cpp
Normal file
@ -0,0 +1,423 @@
|
||||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from whisper.cpp
|
||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<whisper_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
||||
? WHISPER_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<whisper_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<whisper_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_char_element(whisper_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_CHAR: return true;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
whisper_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (auto kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
29
examples/grammar-parser.h
Normal file
@ -0,0 +1,29 @@
|
||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by whisper.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
@ -62,8 +62,8 @@ struct whisper_params {
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
|
@ -53,6 +53,7 @@ struct whisper_params {
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
int32_t n_gpu_layers = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7s] number of layers to store in VRAM\n", params.n_gpu_layers);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
@ -248,7 +251,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
|
||||
auto lmparams = llama_model_default_params();
|
||||
if (!params.use_gpu) {
|
||||
lmparams.n_gpu_layers = 0;
|
||||
} else {
|
||||
lmparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
|
||||
|
@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string word;
|
||||
char word[129];
|
||||
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) word.data(), len);
|
||||
word[len] = '\0';
|
||||
fin.read((char *) word, len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
|
15
examples/whisper.android.java/.gitignore
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
20
examples/whisper.android.java/README.md
Normal file
@ -0,0 +1,20 @@
|
||||
A sample Android app using java code and [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Modify the modelFilePath in the WhisperService.java
|
||||
6. Modify the sampleFilePath in the WhisperService.java
|
||||
7. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
PS:
|
||||
1. Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.
|
||||
2. The cpp code is compiled during the build process
|
||||
3. If you want to import a compiled cpp project in your Android project, please refer to the https://github.com/litongjava/whisper.cpp.android.java.demo
|
||||
|
||||

|
||||
|
BIN
examples/whisper.android.java/README_files/1.jpg
Normal file
After Width: | Height: | Size: 67 KiB |
1
examples/whisper.android.java/app/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/build
|
58
examples/whisper.android.java/app/build.gradle
Normal file
@ -0,0 +1,58 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
}
|
||||
|
||||
android {
|
||||
compileSdkVersion 30
|
||||
buildToolsVersion '30.0.3'
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.litongjava.whisper.android.java"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 30
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
cppFlags ""
|
||||
}
|
||||
}
|
||||
ndk {
|
||||
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "src/main/jni/whisper/CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
ndkVersion "25.2.9519653"
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.appcompat:appcompat:1.1.0'
|
||||
implementation 'com.google.android.material:material:1.1.0'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
|
||||
testImplementation 'junit:junit:4.+'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||
|
||||
//litongjava
|
||||
implementation 'com.litongjava:android-view-inject:1.0'
|
||||
implementation 'com.litongjava:jfinal-aop:1.0.1'
|
||||
implementation 'com.litongjava:litongjava-android-utils:1.0.0'
|
||||
}
|
21
examples/whisper.android.java/app/proguard-rules.pro
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -0,0 +1,26 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class ExampleInstrumentedTest {
|
||||
@Test
|
||||
public void useAppContext() {
|
||||
// Context of the app under test.
|
||||
Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
assertEquals("com.litongjava.whisper.android.java", appContext.getPackageName());
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.litongjava.whisper.android.java">
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:name=".app.App"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.Whisperandroidjava">
|
||||
<activity android:name=".MainActivity">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -0,0 +1,40 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<configuration debug="false" xmlns="http://ch.qos.logback/xml/ns/logback"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://ch.qos.logback/xml/ns/logback https://raw.githubusercontent.com/enricopulatzo/logback-XSD/master/src/main/xsd/logback.xsd
|
||||
http://ch.qos.logback/xml/ns/logback ">
|
||||
<!--Define the storage address of the log file Do not use relative paths in the LogBack configuration. -->
|
||||
<property name="LOG_HOME" value="logs" />
|
||||
<!--Formatted output: %d means the date, %-6level: log level from the left display 6 characters wide, %m: log message, %n is a newline character -->
|
||||
<property name="CONSOLE_LOG_PATTERN"
|
||||
value="%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-6level%logger{0}.%M:%L - %m%n" />
|
||||
|
||||
<!-- console output -->
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<!-- Generate log files on a daily basis -->
|
||||
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
|
||||
<!--File name for log file output -->
|
||||
<fileNamePattern>${LOG_HOME}/project-name-%d{yyyy-MM-dd}.log</fileNamePattern>
|
||||
<!--Maximum size of log file -->
|
||||
<maxHistory>180</maxHistory>
|
||||
</rollingPolicy>
|
||||
<!--日志文件最大的大小 -->
|
||||
<triggeringPolicy class="ch.qos.logback.core.rolling.SizeBasedTriggeringPolicy">
|
||||
<maxFileSize>10MB</maxFileSize>
|
||||
</triggeringPolicy>
|
||||
</appender>
|
||||
<!-- Log output level and source-->
|
||||
<root level="info">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
</configuration>
|
@ -0,0 +1,107 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.view.View;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewById;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewByIdLayout;
|
||||
import com.litongjava.android.view.inject.annotation.OnClick;
|
||||
import com.litongjava.android.view.inject.utils.ViewInjectUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.jfinal.aop.AopManager;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
import com.litongjava.whisper.android.java.task.LoadModelTask;
|
||||
import com.litongjava.whisper.android.java.task.TranscriptionTask;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperLib;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
|
||||
@FindViewByIdLayout(R.layout.activity_main)
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
|
||||
@FindViewById(R.id.sample_text)
|
||||
private TextView tv;
|
||||
|
||||
Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
private WhisperService whisperService = Aop.get(WhisperService.class);
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
//setContentView(R.layout.activity_main);
|
||||
ViewInjectUtils.injectActivity(this, this);
|
||||
initAopBean();
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
private void initAopBean() {
|
||||
Handler mainHandler = new Handler(Looper.getMainLooper());
|
||||
AopManager.me().addSingletonObject(mainHandler);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.loadModelBtn)
|
||||
public void loadModelBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
ThreadUtils.executeByIo(new LoadModelTask(tv));
|
||||
}
|
||||
|
||||
@OnClick(R.id.transcriptSampleBtn)
|
||||
public void transcriptSampleBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
String sampleFilePath = "samples/jfk.wav";
|
||||
File filesDir = context.getFilesDir();
|
||||
File sampleFile = AssetUtils.copyFileIfNotExists(context, filesDir, sampleFilePath);
|
||||
long end = System.currentTimeMillis();
|
||||
String msg = "copy file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ThreadUtils.executeByIo(new TranscriptionTask(tv, sampleFile));
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
tv.append(msg + "\n");
|
||||
log.info(msg);
|
||||
}
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.systemInfoBtn)
|
||||
public void systemInfoBtn_OnClick(View v) {
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void showSystemInfo() {
|
||||
String systemInfo = WhisperLib.getSystemInfo();
|
||||
tv.append(systemInfo + "\n");
|
||||
}
|
||||
|
||||
@OnClick(R.id.clearBtn)
|
||||
public void clearBtn_OnClick(View v) {
|
||||
tv.setText("");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onDestroy() {
|
||||
super.onDestroy();
|
||||
whisperService.release();
|
||||
}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
package com.litongjava.whisper.android.java.app;
|
||||
|
||||
import android.app.Application;
|
||||
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
|
||||
public class App extends Application {
|
||||
@Override
|
||||
public void onCreate() {
|
||||
super.onCreate();
|
||||
Utils.init(this);
|
||||
}
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package com.litongjava.whisper.android.java.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "["+start+" --> "+end+"]:"+sentence;
|
||||
}
|
||||
}
|
@ -0,0 +1,101 @@
|
||||
package com.litongjava.whisper.android.java.services;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.android.utils.dialog.AlertDialogUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.single.LocalWhisper;
|
||||
import com.litongjava.whisper.android.java.utils.WaveEncoder;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
public class WhisperService {
|
||||
private Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
|
||||
private final Object lock = new Object();
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void loadModel(TextView tv) {
|
||||
String modelFilePath = LocalWhisper.modelFilePath;
|
||||
String msg = "load model from :" + modelFilePath + "\n";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
LocalWhisper.INSTANCE.init();
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "model load successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ToastUtils.showLong(msg);
|
||||
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void transcribeSample(TextView tv, File sampleFile) {
|
||||
String msg = "";
|
||||
msg = "transcribe file from :" + sampleFile.getAbsolutePath();
|
||||
outputMsg(tv, msg);
|
||||
|
||||
Long start = System.currentTimeMillis();
|
||||
float[] audioData = new float[0]; // 读取音频样本
|
||||
try {
|
||||
audioData = WaveEncoder.decodeWaveFile(sampleFile);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return;
|
||||
}
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "decode wave file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
start = System.currentTimeMillis();
|
||||
List<WhisperSegment> transcription = null;
|
||||
try {
|
||||
//transcription = LocalWhisper.INSTANCE.transcribeData(audioData);
|
||||
transcription = LocalWhisper.INSTANCE.transcribeDataWithTime(audioData);
|
||||
} catch (ExecutionException e) {
|
||||
e.printStackTrace();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
end = System.currentTimeMillis();
|
||||
if(transcription!=null){
|
||||
ToastUtils.showLong(transcription.toString());
|
||||
msg = "Transcript successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
outputMsg(tv, transcription.toString());
|
||||
|
||||
}else{
|
||||
msg = "Transcript failed:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
log.info(msg);
|
||||
if(tv!=null){
|
||||
Aop.get(Handler.class).post(()->{ tv.append(msg + "\n");});
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() {
|
||||
//noting to do
|
||||
}
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
package com.litongjava.whisper.android.java.single;
|
||||
|
||||
import android.app.Application;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperContext;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public enum LocalWhisper {
|
||||
INSTANCE;
|
||||
|
||||
public static final String modelFilePath = "models/ggml-tiny.bin";
|
||||
private WhisperContext whisperContext;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
LocalWhisper() {
|
||||
Application context = Utils.getApp();
|
||||
File filesDir = context.getFilesDir();
|
||||
File modelFile = AssetUtils.copyFileIfNotExists(context, filesDir, modelFilePath);
|
||||
String realModelFilePath = modelFile.getAbsolutePath();
|
||||
whisperContext = WhisperContext.createContextFromFile(realModelFilePath);
|
||||
}
|
||||
|
||||
public synchronized String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeData(data);
|
||||
}
|
||||
}
|
||||
|
||||
private static void toastModelLoading() {
|
||||
Aop.get(Handler.class).post(()->{
|
||||
ToastUtils.showShort("please wait for model loading");
|
||||
});
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] audioData) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeDataWithTime(audioData);
|
||||
}
|
||||
}
|
||||
|
||||
public void init() {
|
||||
//noting to do.but init
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class LoadModelTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
public LoadModelTask(TextView tv) {
|
||||
this.tv = tv;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).loadModel(tv);
|
||||
}else{
|
||||
Aop.get(Handler.class).post(()->{
|
||||
tv.append("not supported android devices");
|
||||
});
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -0,0 +1,44 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class TranscriptionTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
private final File sampleFile;
|
||||
|
||||
public TranscriptionTask(TextView tv, File sampleFile) {
|
||||
this.tv = tv;
|
||||
this.sampleFile = sampleFile;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).transcribeSample(tv, sampleFile);
|
||||
}else{
|
||||
tv.append("not supported android devices");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -0,0 +1,91 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
public class AssetUtils {
|
||||
private static Logger log = LoggerFactory.getLogger(AssetUtils.class);
|
||||
|
||||
public static File copyFileIfNotExists(Context context, File distDir, String filename) {
|
||||
File dstFile = new File(distDir, filename);
|
||||
if (dstFile.exists()) {
|
||||
return dstFile;
|
||||
} else {
|
||||
File parentFile = dstFile.getParentFile();
|
||||
log.info("parentFile:{}", parentFile);
|
||||
if (!parentFile.exists()) {
|
||||
parentFile.mkdirs();
|
||||
}
|
||||
AssetUtils.copyFileFromAssets(context, filename, dstFile);
|
||||
}
|
||||
return dstFile;
|
||||
}
|
||||
|
||||
public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
|
||||
if (srcDir.isEmpty() || dstDir.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (!new File(dstDir).exists()) {
|
||||
new File(dstDir).mkdirs();
|
||||
}
|
||||
for (String fileName : appCtx.getAssets().list(srcDir)) {
|
||||
String srcSubPath = srcDir + File.separator + fileName;
|
||||
String dstSubPath = dstDir + File.separator + fileName;
|
||||
if (new File(srcSubPath).isDirectory()) {
|
||||
copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
} else {
|
||||
copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
|
||||
File dstFile = new File(dstPath);
|
||||
copyFileFromAssets(appCtx, srcPath, dstFile);
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, File dstFile) {
|
||||
if (srcPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
InputStream is = null;
|
||||
OutputStream os = null;
|
||||
try {
|
||||
is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
|
||||
|
||||
os = new BufferedOutputStream(new FileOutputStream(dstFile));
|
||||
byte[] buffer = new byte[1024];
|
||||
int length = 0;
|
||||
while ((length = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, length);
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
try {
|
||||
os.close();
|
||||
is.close();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -0,0 +1,105 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.ShortBuffer;
|
||||
|
||||
public class WaveEncoder {
|
||||
|
||||
public static float[] decodeWaveFile(File file) throws IOException {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
try (FileInputStream fis = new FileInputStream(file)) {
|
||||
byte[] buffer = new byte[1024];
|
||||
int bytesRead;
|
||||
while ((bytesRead = fis.read(buffer)) != -1) {
|
||||
baos.write(buffer, 0, bytesRead);
|
||||
}
|
||||
}
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
|
||||
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
int channel = byteBuffer.getShort(22);
|
||||
byteBuffer.position(44);
|
||||
|
||||
ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
|
||||
short[] shortArray = new short[shortBuffer.limit()];
|
||||
shortBuffer.get(shortArray);
|
||||
|
||||
float[] output = new float[shortArray.length / channel];
|
||||
|
||||
for (int index = 0; index < output.length; index++) {
|
||||
if (channel == 1) {
|
||||
output[index] = Math.max(-1f, Math.min(1f, shortArray[index] / 32767.0f));
|
||||
} else {
|
||||
output[index] = Math.max(-1f, Math.min(1f, (shortArray[2 * index] + shortArray[2 * index + 1]) / 32767.0f / 2.0f));
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
public static void encodeWaveFile(File file, short[] data) throws IOException {
|
||||
try (FileOutputStream fos = new FileOutputStream(file)) {
|
||||
fos.write(headerBytes(data.length * 2));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(data.length * 2);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.asShortBuffer().put(data);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
fos.write(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] headerBytes(int totalLength) {
|
||||
if (totalLength < 44)
|
||||
throw new IllegalArgumentException("Total length must be at least 44 bytes");
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(44);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
buffer.put((byte) 'R');
|
||||
buffer.put((byte) 'I');
|
||||
buffer.put((byte) 'F');
|
||||
buffer.put((byte) 'F');
|
||||
|
||||
buffer.putInt(totalLength - 8);
|
||||
|
||||
buffer.put((byte) 'W');
|
||||
buffer.put((byte) 'A');
|
||||
buffer.put((byte) 'V');
|
||||
buffer.put((byte) 'E');
|
||||
|
||||
buffer.put((byte) 'f');
|
||||
buffer.put((byte) 'm');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) ' ');
|
||||
|
||||
buffer.putInt(16);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putInt(16000);
|
||||
buffer.putInt(32000);
|
||||
buffer.putShort((short) 2);
|
||||
buffer.putShort((short) 16);
|
||||
|
||||
buffer.put((byte) 'd');
|
||||
buffer.put((byte) 'a');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) 'a');
|
||||
|
||||
buffer.putInt(totalLength - 44);
|
||||
buffer.position(0);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class CpuInfo {
|
||||
private static final String LOG_TAG = "WhisperCpuConfig";
|
||||
|
||||
private List<String> lines;
|
||||
|
||||
public CpuInfo(List<String> lines) {
|
||||
this.lines = lines;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public int getHighPerfCpuCount0() {
|
||||
try {
|
||||
return getHighPerfCpuCountByFrequencies();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e);
|
||||
return getHighPerfCpuCountByVariant();
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByFrequencies() {
|
||||
List<Integer> frequencies = getCpuValues("processor", line -> {
|
||||
try {
|
||||
return getMaxCpuFrequency(Integer.parseInt(line.trim()));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
);
|
||||
Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): " + binnedValues(frequencies));
|
||||
return countDroppingMin(frequencies);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByVariant() {
|
||||
List<Integer> variants = getCpuValues("CPU variant", line -> Integer.parseInt(line.trim().substring(line.indexOf("0x") + 2), 16));
|
||||
Log.d(LOG_TAG, "Binned cpu variants (variant, count): " + binnedValues(variants));
|
||||
return countKeepingMin(variants);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private Map<Integer, Integer> binnedValues(List<Integer> values) {
|
||||
Map<Integer, Integer> countMap = new HashMap<>();
|
||||
for (int value : values) {
|
||||
countMap.put(value, countMap.getOrDefault(value, 0) + 1);
|
||||
}
|
||||
return countMap;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private List<Integer> getCpuValues(String property, Mapper mapper) {
|
||||
List<Integer> values = new ArrayList<>();
|
||||
for (String line : lines) {
|
||||
if (line.startsWith(property)) {
|
||||
values.add(mapper.map(line.substring(line.indexOf(':') + 1)));
|
||||
}
|
||||
}
|
||||
values.sort(Integer::compareTo);
|
||||
return values;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countDroppingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value > min).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countKeepingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value.equals(min)).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getHighPerfCpuCount() {
|
||||
try {
|
||||
return readCpuInfo().getHighPerfCpuCount0();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU info", e);
|
||||
return Math.max(Runtime.getRuntime().availableProcessors() - 4, 0);
|
||||
}
|
||||
}
|
||||
|
||||
private static CpuInfo readCpuInfo() throws IOException {
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader("/proc/cpuinfo"))) {
|
||||
List<String> lines = new ArrayList<>();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
lines.add(line);
|
||||
}
|
||||
return new CpuInfo(lines);
|
||||
}
|
||||
}
|
||||
|
||||
private static int getMaxCpuFrequency(int cpuIndex) throws IOException {
|
||||
String path = "/sys/devices/system/cpu/cpu" + cpuIndex + "/cpufreq/cpuinfo_max_freq";
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
return Integer.parseInt(reader.readLine());
|
||||
}
|
||||
}
|
||||
|
||||
private interface Mapper {
|
||||
int map(String line);
|
||||
}
|
||||
}
|
@ -0,0 +1,138 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
public class WhisperContext {
|
||||
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
private long ptr;
|
||||
private final ExecutorService executorService;
|
||||
|
||||
private WhisperContext(long ptr) {
|
||||
this.ptr = ptr;
|
||||
this.executorService = Executors.newSingleThreadExecutor();
|
||||
}
|
||||
|
||||
public String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<String>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public String call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
StringBuilder result = new StringBuilder();
|
||||
synchronized (this) {
|
||||
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
result.append(sentence);
|
||||
}
|
||||
}
|
||||
return result.toString();
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<List<WhisperSegment>>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public List<WhisperSegment> call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
List<WhisperSegment> segments = new ArrayList<>();
|
||||
synchronized (this) {
|
||||
// StringBuilder result = new StringBuilder();
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
long start = WhisperLib.getTextSegmentT0(ptr, i);
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
long end = WhisperLib.getTextSegmentT1(ptr, i);
|
||||
// result.append();
|
||||
segments.add(new WhisperSegment(start, end, sentence));
|
||||
|
||||
}
|
||||
// return result.toString();
|
||||
}
|
||||
return segments;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchMemory(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchMemcpy(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchGgmlMulMat(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchGgmlMulMat(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() throws ExecutionException, InterruptedException {
|
||||
executorService.submit(() -> {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr);
|
||||
ptr = 0;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromFile(String filePath) {
|
||||
long ptr = WhisperLib.initContext(filePath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context with path " + filePath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromInputStream(InputStream stream) {
|
||||
long ptr = WhisperLib.initContextFromInputStream(stream);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from input stream");
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromAsset(AssetManager assetManager, String assetPath) {
|
||||
long ptr = WhisperLib.initContextFromAsset(assetManager, assetPath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from asset " + assetPath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String getSystemInfo() {
|
||||
return WhisperLib.getSystemInfo();
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
public class WhisperCpuConfig {
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getPreferredThreadCount() {
|
||||
return Math.max(CpuInfo.getHighPerfCpuCount(), 2);
|
||||
}
|
||||
}
|
@ -0,0 +1,75 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public class WhisperLib {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
static {
|
||||
|
||||
Log.d(LOG_TAG, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
|
||||
boolean loadVfpv4 = false;
|
||||
boolean loadV8fp16 = false;
|
||||
if (WhisperUtils.isArmEabiV7a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("vfpv4")) {
|
||||
Log.d(LOG_TAG, "CPU supports vfpv4");
|
||||
loadVfpv4 = true;
|
||||
}
|
||||
}
|
||||
} else if (WhisperUtils.isArmEabiV8a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("fphp")) {
|
||||
Log.d(LOG_TAG, "CPU supports fp16 arithmetic");
|
||||
loadV8fp16 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (loadVfpv4) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so");
|
||||
System.loadLibrary("whisper_vfpv4");
|
||||
} else if (loadV8fp16) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so");
|
||||
System.loadLibrary("whisper_v8fp16_va");
|
||||
} else {
|
||||
Log.d(LOG_TAG, "Loading libwhisper.so");
|
||||
System.loadLibrary("whisper");
|
||||
}
|
||||
}
|
||||
|
||||
public static native long initContextFromInputStream(InputStream inputStream);
|
||||
|
||||
public static native long initContextFromAsset(AssetManager assetManager, String assetPath);
|
||||
|
||||
public static native long initContext(String modelPath);
|
||||
|
||||
public static native void freeContext(long contextPtr);
|
||||
|
||||
public static native void fullTranscribe(long contextPtr, int numThreads, float[] audioData);
|
||||
|
||||
public static native int getTextSegmentCount(long contextPtr);
|
||||
|
||||
public static native String getTextSegment(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT0(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT1(long contextPtr, int index);
|
||||
|
||||
public static native String getSystemInfo();
|
||||
|
||||
public static native String benchMemcpy(int nthread);
|
||||
|
||||
public static native String benchGgmlMulMat(int nthread);
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class WhisperUtils {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
|
||||
public static boolean isArmEabiV7a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a");
|
||||
}
|
||||
|
||||
public static boolean isArmEabiV8a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String cpuInfo() {
|
||||
try {
|
||||
Path path = new File("/proc/cpuinfo").toPath();
|
||||
return new String(java.nio.file.Files.readAllBytes(path));
|
||||
} catch (Exception e) {
|
||||
Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
cmake_minimum_required(VERSION 3.10)
|
||||
|
||||
project(whisper.cpp)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
|
||||
|
||||
set(
|
||||
SOURCE_FILES
|
||||
${WHISPER_LIB_DIR}/ggml.c
|
||||
${WHISPER_LIB_DIR}/ggml-alloc.c
|
||||
${WHISPER_LIB_DIR}/ggml-backend.c
|
||||
${WHISPER_LIB_DIR}/ggml-quants.c
|
||||
${WHISPER_LIB_DIR}/whisper.cpp
|
||||
${CMAKE_SOURCE_DIR}/jni.c
|
||||
)
|
||||
|
||||
find_library(LOG_LIB log)
|
||||
|
||||
function(build_library target_name)
|
||||
add_library(
|
||||
${target_name}
|
||||
SHARED
|
||||
${SOURCE_FILES}
|
||||
)
|
||||
|
||||
target_link_libraries(${target_name} ${LOG_LIB} android)
|
||||
|
||||
if (${target_name} STREQUAL "whisper_v8fp16_va")
|
||||
target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16)
|
||||
elseif (${target_name} STREQUAL "whisper_vfpv4")
|
||||
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
|
||||
endif ()
|
||||
|
||||
if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
|
||||
target_compile_options(${target_name} PRIVATE -O3)
|
||||
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
|
||||
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
|
||||
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
|
||||
#target_link_options(${target_name} PRIVATE -flto)
|
||||
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
build_library("whisper") # Default target
|
||||
|
||||
if (${ANDROID_ABI} STREQUAL "arm64-v8a")
|
||||
build_library("whisper_v8fp16_va")
|
||||
elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a")
|
||||
build_library("whisper_vfpv4")
|
||||
endif ()
|
||||
|
||||
include_directories(${WHISPER_LIB_DIR})
|
257
examples/whisper.android.java/app/src/main/jni/whisper/jni.c
Normal file
@ -0,0 +1,257 @@
|
||||
#include <jni.h>
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <string.h>
|
||||
#include "whisper.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||
|
||||
static inline int min(int a, int b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
static inline int max(int a, int b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
struct input_stream_context {
|
||||
size_t offset;
|
||||
JNIEnv * env;
|
||||
jobject thiz;
|
||||
jobject input_stream;
|
||||
|
||||
jmethodID mid_available;
|
||||
jmethodID mid_read;
|
||||
};
|
||||
|
||||
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
|
||||
|
||||
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
|
||||
|
||||
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
|
||||
|
||||
if (size_to_copy != read_size || size_to_copy != n_read) {
|
||||
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
|
||||
}
|
||||
|
||||
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
|
||||
memcpy(output, byte_array_elements, size_to_copy);
|
||||
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
|
||||
|
||||
(*is->env)->DeleteLocalRef(is->env, byte_array);
|
||||
|
||||
is->offset += size_to_copy;
|
||||
|
||||
return size_to_copy;
|
||||
}
|
||||
bool inputStreamEof(void * ctx) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
return result <= 0;
|
||||
}
|
||||
void inputStreamClose(void * ctx) {
|
||||
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromInputStream(
|
||||
JNIEnv *env, jobject thiz, jobject input_stream) {
|
||||
UNUSED(thiz);
|
||||
|
||||
struct whisper_context *context = NULL;
|
||||
struct whisper_model_loader loader = {};
|
||||
struct input_stream_context inp_ctx = {};
|
||||
|
||||
inp_ctx.offset = 0;
|
||||
inp_ctx.env = env;
|
||||
inp_ctx.thiz = thiz;
|
||||
inp_ctx.input_stream = input_stream;
|
||||
|
||||
jclass cls = (*env)->GetObjectClass(env, input_stream);
|
||||
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
|
||||
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
|
||||
|
||||
loader.context = &inp_ctx;
|
||||
loader.read = inputStreamRead;
|
||||
loader.eof = inputStreamEof;
|
||||
loader.close = inputStreamClose;
|
||||
|
||||
loader.eof(loader.context);
|
||||
|
||||
context = whisper_init(&loader);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
static size_t asset_read(void *ctx, void *output, size_t read_size) {
|
||||
return AAsset_read((AAsset *) ctx, output, read_size);
|
||||
}
|
||||
|
||||
static bool asset_is_eof(void *ctx) {
|
||||
return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
|
||||
}
|
||||
|
||||
static void asset_close(void *ctx) {
|
||||
AAsset_close((AAsset *) ctx);
|
||||
}
|
||||
|
||||
static struct whisper_context *whisper_init_from_asset(
|
||||
JNIEnv *env,
|
||||
jobject assetManager,
|
||||
const char *asset_path
|
||||
) {
|
||||
LOGI("Loading model from asset '%s'\n", asset_path);
|
||||
AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
|
||||
AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING);
|
||||
if (!asset) {
|
||||
LOGW("Failed to open '%s'\n", asset_path);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_model_loader loader = {
|
||||
.context = asset,
|
||||
.read = &asset_read,
|
||||
.eof = &asset_is_eof,
|
||||
.close = &asset_close
|
||||
};
|
||||
|
||||
return whisper_init(&loader);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromAsset(
|
||||
JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL);
|
||||
context = whisper_init_from_asset(env, assetManager, asset_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContext(
|
||||
JNIEnv *env, jobject thiz, jstring model_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
|
||||
context = whisper_init_from_file(model_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_freeContext(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
whisper_free(context);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = num_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
||||
whisper_reset_timings(context);
|
||||
|
||||
LOGI("About to run whisper_full");
|
||||
if (whisper_full(context, params, audio_data_arr, audio_data_length) != 0) {
|
||||
LOGI("Failed to run the model");
|
||||
} else {
|
||||
whisper_print_timings(context);
|
||||
}
|
||||
(*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentCount(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
return whisper_full_n_segments(context);
|
||||
}
|
||||
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegment(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT0(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t0 = whisper_full_get_segment_t0(context, index);
|
||||
return (jlong)t0;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT1(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t1 = whisper_full_get_segment_t1(context, index);
|
||||
return (jlong)t1;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getSystemInfo(
|
||||
JNIEnv *env, jobject thiz
|
||||
) {
|
||||
UNUSED(thiz);
|
||||
const char *sysinfo = whisper_print_system_info();
|
||||
jstring string = (*env)->NewStringUTF(env, sysinfo);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchMemcpy(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchGgmlMulMat(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
|
||||
}
|
||||
|
@ -0,0 +1,30 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
@ -0,0 +1,170 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
@ -0,0 +1,57 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/systemInfoBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="System Info" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/loadModelBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Load model" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/transcriptSampleBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Transcribe sample" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/clearBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Clear" />
|
||||
</LinearLayout>
|
||||
|
||||
<TextView
|
||||
android:id="@+id/sample_text"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Hello World!"
|
||||
app:layout_constraintBottom_toBottomOf="parent"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent"
|
||||
android:scrollbarAlwaysDrawHorizontalTrack="true"
|
||||
android:maxLines="999"/>
|
||||
|
||||
</LinearLayout>
|
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
After Width: | Height: | Size: 3.5 KiB |
After Width: | Height: | Size: 5.2 KiB |
After Width: | Height: | Size: 2.6 KiB |
After Width: | Height: | Size: 3.3 KiB |
After Width: | Height: | Size: 4.8 KiB |
After Width: | Height: | Size: 7.3 KiB |
After Width: | Height: | Size: 7.7 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 16 KiB |
@ -0,0 +1,16 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_200</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/black</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_200</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">whisper.android.java</string>
|
||||
</resources>
|
@ -0,0 +1,16 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_500</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/white</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_700</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -0,0 +1,17 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
public class ExampleUnitTest {
|
||||
@Test
|
||||
public void addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2);
|
||||
}
|
||||
}
|
24
examples/whisper.android.java/build.gradle
Normal file
@ -0,0 +1,24 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
buildscript {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
dependencies {
|
||||
classpath "com.android.tools.build:gradle:4.1.3"
|
||||
|
||||
// NOTE: Do not place your application dependencies here; they belong
|
||||
// in the individual module build.gradle files
|
||||
}
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
task clean(type: Delete) {
|
||||
delete rootProject.buildDir
|
||||
}
|
19
examples/whisper.android.java/gradle.properties
Normal file
@ -0,0 +1,19 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app"s APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Automatically convert third-party libraries to use AndroidX
|
||||
android.enableJetifier=true
|
BIN
examples/whisper.android.java/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
6
examples/whisper.android.java/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
#Fri Oct 20 11:07:15 HST 2023
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-all.zip
|
172
examples/whisper.android.java/gradlew
vendored
Normal file
@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS=""
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin, switch paths to Windows format before running java
|
||||
if $cygwin ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=$((i+1))
|
||||
done
|
||||
case $i in
|
||||
(0) set -- ;;
|
||||
(1) set -- "$args0" ;;
|
||||
(2) set -- "$args0" "$args1" ;;
|
||||
(3) set -- "$args0" "$args1" "$args2" ;;
|
||||
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=$(save "$@")
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
|
||||
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
|
||||
cd "$(dirname "$0")"
|
||||
fi
|
||||
|
||||
exec "$JAVACMD" "$@"
|
84
examples/whisper.android.java/gradlew.bat
vendored
Normal file
@ -0,0 +1,84 @@
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS=
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:init
|
||||
@rem Get command-line arguments, handling Windows variants
|
||||
|
||||
if not "%OS%" == "Windows_NT" goto win9xME_args
|
||||
|
||||
:win9xME_args
|
||||
@rem Slurp the command line arguments.
|
||||
set CMD_LINE_ARGS=
|
||||
set _SKIP=2
|
||||
|
||||
:win9xME_args_slurp
|
||||
if "x%~1" == "x" goto execute
|
||||
|
||||
set CMD_LINE_ARGS=%*
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
2
examples/whisper.android.java/settings.gradle
Normal file
@ -0,0 +1,2 @@
|
||||
include ':app'
|
||||
rootProject.name = "whisper.android.java"
|
2
examples/whisper.android/.idea/gradle.xml
generated
@ -4,6 +4,7 @@
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
|
||||
<option name="modules">
|
||||
@ -13,6 +14,7 @@
|
||||
</set>
|
||||
</option>
|
||||
<option name="resolveExternalAnnotations" value="false" />
|
||||
<option name="resolveModulePerSourceSet" value="false" />
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
|
@ -1,4 +1,5 @@
|
||||
import Foundation
|
||||
import whisper
|
||||
|
||||
enum WhisperError: Error {
|
||||
case couldNotInitializeContext
|
||||
|
@ -1,4 +0,0 @@
|
||||
//
|
||||
// Use this file to import your target's public headers that you would like to expose to Swift.
|
||||
//
|
||||
#import "whisper.h"
|
@ -15,16 +15,9 @@
|
||||
0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9A29539CCF003032C3 /* WhisperCppDemoApp.swift */; };
|
||||
0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9C29539CCF003032C3 /* ContentView.swift */; };
|
||||
0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9E29539CD0003032C3 /* Assets.xcassets */; };
|
||||
0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */; };
|
||||
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DGGML_USE_METAL -Wno-shorten-64-to-32"; }; };
|
||||
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL -Wno-shorten-64-to-32"; }; };
|
||||
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
|
||||
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
|
||||
18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE14C2AF555FA0044A204 /* ggml-backend.c */; };
|
||||
18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1512AF555FA0044A204 /* ggml-quants.c */; };
|
||||
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
|
||||
7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08252ACFA3A400AF3530 /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
|
||||
7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08272ACFA48500AF3530 /* ggml-metal.metal */; };
|
||||
E3F92DC52AFA8E3800A6A9D4 /* whisper in Frameworks */ = {isa = PBXBuildFile; productRef = E3F92DC42AFA8E3800A6A9D4 /* whisper */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
@ -38,25 +31,9 @@
|
||||
0AAC5D9C29539CCF003032C3 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
|
||||
0AAC5D9E29539CD0003032C3 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = "<group>"; };
|
||||
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
|
||||
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = "<group>"; };
|
||||
0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = "<group>"; };
|
||||
0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = "<group>"; };
|
||||
0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = "<group>"; };
|
||||
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
|
||||
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
|
||||
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
|
||||
18ABE14C2AF555FA0044A204 /* ggml-backend.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-backend.c"; sourceTree = "<group>"; };
|
||||
18ABE14D2AF555FA0044A204 /* ggml-backend.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-backend.h"; sourceTree = "<group>"; };
|
||||
18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-backend-impl.h"; sourceTree = "<group>"; };
|
||||
18ABE14F2AF555FA0044A204 /* ggml-quants.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-quants.h"; sourceTree = "<group>"; };
|
||||
18ABE1502AF555FA0044A204 /* ggml-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-impl.h"; sourceTree = "<group>"; };
|
||||
18ABE1512AF555FA0044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-quants.c"; sourceTree = "<group>"; };
|
||||
18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
|
||||
18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
|
||||
7FCB081E2ACFA04400AF3530 /* ggml-metal.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-metal.h"; sourceTree = "<group>"; };
|
||||
7FCB08252ACFA3A400AF3530 /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "ggml-metal.m"; sourceTree = "<group>"; };
|
||||
7FCB08272ACFA48500AF3530 /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = "ggml-metal.metal"; sourceTree = "<group>"; };
|
||||
E3F92DC22AFA8DD800A6A9D4 /* whisper.cpp */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = whisper.cpp; path = ../..; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
@ -64,6 +41,7 @@
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
E3F92DC52AFA8E3800A6A9D4 /* whisper in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
@ -99,11 +77,12 @@
|
||||
0AAC5D8E29539CCF003032C3 = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
E3F92DC22AFA8DD800A6A9D4 /* whisper.cpp */,
|
||||
0A8E48FF2954B3F100704C1B /* README.md */,
|
||||
0AAC5DC529539E89003032C3 /* whisper.cpp */,
|
||||
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */,
|
||||
0AAC5D9929539CCF003032C3 /* whisper.swiftui.demo */,
|
||||
0AAC5D9829539CCF003032C3 /* Products */,
|
||||
E3F92DC32AFA8E3800A6A9D4 /* Frameworks */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
@ -128,42 +107,9 @@
|
||||
path = whisper.swiftui.demo;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
0AAC5DA129539CD0003032C3 /* Preview Content */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */,
|
||||
);
|
||||
name = "Preview Content";
|
||||
path = "../Preview Content";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
7FCB08272ACFA48500AF3530 /* ggml-metal.metal */,
|
||||
7FCB081E2ACFA04400AF3530 /* ggml-metal.h */,
|
||||
7FCB08252ACFA3A400AF3530 /* ggml-metal.m */,
|
||||
18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */,
|
||||
18ABE14C2AF555FA0044A204 /* ggml-backend.c */,
|
||||
18ABE14D2AF555FA0044A204 /* ggml-backend.h */,
|
||||
18ABE1502AF555FA0044A204 /* ggml-impl.h */,
|
||||
18ABE1512AF555FA0044A204 /* ggml-quants.c */,
|
||||
18ABE14F2AF555FA0044A204 /* ggml-quants.h */,
|
||||
18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
|
||||
18AED4802AB21F2B009D854F /* ggml-alloc.h */,
|
||||
0AAC5DC929539EB0003032C3 /* ggml.c */,
|
||||
0AAC5DCA29539EB0003032C3 /* ggml.h */,
|
||||
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
|
||||
0AAC5DC829539EB0003032C3 /* whisper.h */,
|
||||
);
|
||||
name = whisper.cpp;
|
||||
path = ../..;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */,
|
||||
0AAC5DD02953A394003032C3 /* LibWhisper.swift */,
|
||||
);
|
||||
path = whisper.cpp.swift;
|
||||
@ -182,11 +128,17 @@
|
||||
children = (
|
||||
0AAC5D9E29539CD0003032C3 /* Assets.xcassets */,
|
||||
0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */,
|
||||
0AAC5DA129539CD0003032C3 /* Preview Content */,
|
||||
);
|
||||
path = "Supporting files";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
E3F92DC32AFA8E3800A6A9D4 /* Frameworks */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
);
|
||||
name = Frameworks;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
@ -203,6 +155,9 @@
|
||||
dependencies = (
|
||||
);
|
||||
name = whisper.swiftui;
|
||||
packageProductDependencies = (
|
||||
E3F92DC42AFA8E3800A6A9D4 /* whisper */,
|
||||
);
|
||||
productName = WhisperCppDemo;
|
||||
productReference = 0AAC5D9729539CCF003032C3 /* whisper.swiftui.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
@ -247,7 +202,6 @@
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
0AA751482953AC2E001EE061 /* samples in Resources */,
|
||||
0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */,
|
||||
0A8E49002954B3F100704C1B /* README.md in Resources */,
|
||||
0AA751492953AC2E001EE061 /* models in Resources */,
|
||||
0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */,
|
||||
@ -263,17 +217,10 @@
|
||||
files = (
|
||||
0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */,
|
||||
0AAC5D9B29539CCF003032C3 /* WhisperCppDemoApp.swift in Sources */,
|
||||
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */,
|
||||
18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */,
|
||||
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */,
|
||||
7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */,
|
||||
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */,
|
||||
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
|
||||
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
|
||||
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
|
||||
7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */,
|
||||
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
|
||||
18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
@ -401,7 +348,7 @@
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||
DEVELOPMENT_TEAM = "";
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
@ -425,7 +372,6 @@
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_OBJC_BRIDGING_HEADER = "whisper.cpp.swift/WhisperCppDemo-Bridging-Header.h";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
@ -442,7 +388,7 @@
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||
DEVELOPMENT_TEAM = "";
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
@ -471,7 +417,6 @@
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_OBJC_BRIDGING_HEADER = "whisper.cpp.swift/WhisperCppDemo-Bridging-Header.h";
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
@ -499,6 +444,13 @@
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
|
||||
/* Begin XCSwiftPackageProductDependency section */
|
||||
E3F92DC42AFA8E3800A6A9D4 /* whisper */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
productName = whisper;
|
||||
};
|
||||
/* End XCSwiftPackageProductDependency section */
|
||||
};
|
||||
rootObject = 0AAC5D8F29539CCF003032C3 /* Project object */;
|
||||
}
|
||||
|
@ -18,11 +18,11 @@ else
|
||||
fi
|
||||
|
||||
models=( \
|
||||
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
|
||||
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
|
||||
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
|
||||
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
|
||||
)
|
||||
|
||||
if [ "$encoder_only" -eq 0 ]; then
|
||||
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# actual run
|
||||
@ -56,6 +56,7 @@ for model in "${models[@]}"; do
|
||||
# parse the output:
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
|
||||
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||
system_info=$(echo "$output" | grep "system_info")
|
||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||
@ -83,6 +84,10 @@ for model in "${models[@]}"; do
|
||||
config="$config COREML"
|
||||
fi
|
||||
|
||||
if [[ $system_info == *"CUDA = 1"* ]]; then
|
||||
config="$config CUDA"
|
||||
fi
|
||||
|
||||
if [[ $system_info == *"METAL = 1"* ]]; then
|
||||
config="$config METAL"
|
||||
fi
|
||||
@ -90,6 +95,6 @@ for model in "${models[@]}"; do
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
@ -15,33 +15,13 @@ declare -a filedex
|
||||
cd `dirname $0`
|
||||
cd ../
|
||||
|
||||
# Let's loop across all the objects in the 'models' dir:
|
||||
for i in ./models/*; do
|
||||
# Check to see if it's a file or directory
|
||||
if [ -d "$i" ]; then
|
||||
# It's a directory! We should make sure it's not empty first:
|
||||
if [ "$(ls -A $i)" ]; then
|
||||
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
|
||||
for f in "$i"/*.bin; do
|
||||
# [Neuron Activation]
|
||||
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
|
||||
if [ "$newfile" != "q5" ]; then
|
||||
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
|
||||
fi
|
||||
done
|
||||
fi
|
||||
else
|
||||
# It's a file! Let's make sure it's the right type:
|
||||
if [ "${i##*.}" == "bin" ]; then
|
||||
# And we probably want to skip the testing files
|
||||
if [ "${i:9:8}" != "for-test" ]; then
|
||||
# [Neuron Activation]
|
||||
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
|
||||
fi
|
||||
for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
|
||||
m="models/$i"
|
||||
if [ -f "$m" ]; then
|
||||
if [ "${m##*.}" == "bin" ]; then
|
||||
./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
407
ggml-cuda.cu
@ -39,7 +39,6 @@
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceGetMemPool hipDeviceGetMemPool
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
@ -49,7 +48,6 @@
|
||||
#define cudaEvent_t hipEvent_t
|
||||
#define cudaEventDestroy hipEventDestroy
|
||||
#define cudaFree hipFree
|
||||
#define cudaFreeAsync hipFreeAsync
|
||||
#define cudaFreeHost hipHostFree
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaGetDeviceCount hipGetDeviceCount
|
||||
@ -57,7 +55,6 @@
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetLastError hipGetLastError
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocFromPoolAsync hipMallocFromPoolAsync
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||
@ -66,9 +63,6 @@
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemPool_t hipMemPool_t
|
||||
#define cudaMemPoolAttrReleaseThreshold hipMemPoolAttrReleaseThreshold
|
||||
#define cudaMemPoolSetAttribute hipMemPoolSetAttribute
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
@ -94,6 +88,8 @@
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
|
||||
|
||||
#define GGML_CUDA_MAX_NODES 8192
|
||||
|
||||
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
||||
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
||||
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
||||
@ -188,11 +184,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
int dev_id; \
|
||||
cudaGetDevice(&dev_id); \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
@ -202,11 +198,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
int dev_id; \
|
||||
cudaGetDevice(&dev_id); \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
|
||||
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", dev_id); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
@ -440,6 +436,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
#define CUDA_GELU_BLOCK_SIZE 256
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_RELU_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_CPY_BLOCK_SIZE 32
|
||||
#define CUDA_SCALE_BLOCK_SIZE 256
|
||||
#define CUDA_CLAMP_BLOCK_SIZE 256
|
||||
@ -472,7 +470,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
|
||||
|
||||
#define MAX_STREAMS 8
|
||||
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
|
||||
static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
|
||||
|
||||
struct ggml_tensor_extra_gpu {
|
||||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
||||
@ -561,6 +558,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
static __global__ void relu_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = fmaxf(x[i], 0);
|
||||
}
|
||||
|
||||
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -990,7 +1005,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -1094,7 +1109,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -1198,7 +1213,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
const int ib0 = row*num_blocks_per_row;
|
||||
@ -1452,7 +1467,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
|
||||
|
||||
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
|
||||
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
if (row > nrows) return;
|
||||
|
||||
const int num_blocks_per_row = ncols / QK_K;
|
||||
@ -4262,7 +4277,7 @@ template <bool need_check> static __global__ void
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
@ -4302,7 +4317,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
// qk = quantized weights per x block
|
||||
// qr = number of quantized weights per data value in x block
|
||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
@ -4476,6 +4491,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
||||
*dsti = __float2half(*xi);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
||||
const half * xi = (const half *) cxi;
|
||||
half * dsti = (half *) cdsti;
|
||||
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||
@ -4729,6 +4751,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||
}
|
||||
|
||||
static __global__ void im2col_f32_f16(
|
||||
const float * x, half * dst,
|
||||
int ofs0, int ofs1, int IW, int IH, int CHW,
|
||||
int s0, int s1, int p0, int p1, int d0, int d1) {
|
||||
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
|
||||
const int offset_dst =
|
||||
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = __float2half(0.0f);
|
||||
} else {
|
||||
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
}
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
@ -4767,6 +4808,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
|
||||
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
@ -4875,7 +4926,8 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4884,7 +4936,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4893,7 +4945,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4902,7 +4954,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4911,7 +4963,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
|
||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -4921,7 +4973,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4930,7 +4982,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4939,7 +4991,7 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4954,7 +5006,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int ny = 2 / K_QUANTS_PER_ITERATION;
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(32, ny, 1);
|
||||
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
@ -4962,7 +5014,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -4971,7 +5023,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK4_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -4980,7 +5032,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK5_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -4989,7 +5041,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK5_1 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -4998,7 +5050,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK8_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5007,7 +5059,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5016,7 +5068,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5025,7 +5077,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5034,7 +5086,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5043,7 +5095,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
@ -5062,7 +5114,7 @@ static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cu
|
||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(1, block_num_y, 1);
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
@ -5618,6 +5670,16 @@ static void ggml_cpy_f32_f16_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f16_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
|
||||
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
|
||||
}
|
||||
|
||||
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
|
||||
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
|
||||
@ -5701,6 +5763,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
|
||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||
}
|
||||
|
||||
static void im2col_f32_f16_cuda(const float * x, half * dst,
|
||||
int OH, int IW, int IH, int OW, int IC,
|
||||
int KH, int KW, int N, int ofs0, int ofs1,
|
||||
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
|
||||
dim3 block_nums(IC, OH, OW);
|
||||
dim3 block_dims(N, KH, KW);
|
||||
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||
}
|
||||
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 256
|
||||
|
||||
@ -5780,16 +5851,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
|
||||
if (g_cudaMemPools[id] == nullptr) {
|
||||
return ggml_cuda_pool_malloc(size, actual_size);
|
||||
}
|
||||
void *ptr;
|
||||
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
|
||||
*actual_size = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
int id;
|
||||
@ -5807,12 +5868,10 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
static bool g_cublas_loaded = false;
|
||||
|
||||
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
|
||||
if (g_cudaMemPools[id] == nullptr) {
|
||||
return ggml_cuda_pool_free(ptr, actual_size);
|
||||
}
|
||||
CUDA_CHECK(cudaFreeAsync(ptr, stream));
|
||||
bool ggml_cublas_loaded(void) {
|
||||
return g_cublas_loaded;
|
||||
}
|
||||
|
||||
void ggml_init_cublas() {
|
||||
@ -5827,7 +5886,12 @@ void ggml_init_cublas() {
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
#endif
|
||||
|
||||
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
|
||||
if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
|
||||
initialized = true;
|
||||
g_cublas_loaded = false;
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
|
||||
int64_t total_vram = 0;
|
||||
#if defined(GGML_CUDA_FORCE_MMQ)
|
||||
@ -5869,19 +5933,13 @@ void ggml_init_cublas() {
|
||||
// create cublas handle
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
|
||||
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
|
||||
|
||||
// configure memory pool
|
||||
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
|
||||
if (err == cudaSuccess) {
|
||||
size_t treshold = UINT64_MAX;
|
||||
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
|
||||
}
|
||||
}
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
|
||||
|
||||
initialized = true;
|
||||
g_cublas_loaded = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -6148,6 +6206,34 @@ inline void ggml_cuda_op_silu(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_relu(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_sqr(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_norm(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6469,7 +6555,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
size_t ne = row_diff*ne00;
|
||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
|
||||
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
|
||||
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
|
||||
}
|
||||
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
|
||||
@ -6480,12 +6566,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
|
||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
||||
size_t dst_f16_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
@ -6503,15 +6589,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||
|
||||
if (dst_f16_as != 0) {
|
||||
ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
|
||||
}
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
}
|
||||
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
}
|
||||
}
|
||||
else {
|
||||
@ -6521,7 +6606,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||
}
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||
@ -6538,7 +6623,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
&beta, dst_dd_i, ldc));
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
|
||||
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
|
||||
}
|
||||
}
|
||||
|
||||
@ -6659,6 +6744,45 @@ inline void ggml_cuda_op_alibi(
|
||||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_im2col(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||
|
||||
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
||||
|
||||
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
||||
|
||||
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
||||
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
||||
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
||||
const int64_t IW = src1->ne[0];
|
||||
|
||||
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
||||
const int64_t KW = src0->ne[0];
|
||||
|
||||
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const int64_t OW = dst->ne[1];
|
||||
|
||||
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
||||
OH, IW, IH, OW, IC, KH, KW, N,
|
||||
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
|
||||
|
||||
(void) src0;
|
||||
(void) src0_dd;
|
||||
}
|
||||
|
||||
inline void ggml_cuda_op_diag_mask_inf(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
|
||||
@ -6924,6 +7048,8 @@ static void ggml_cuda_op_mul_mat(
|
||||
int64_t row_low[GGML_CUDA_MAX_DEVICES];
|
||||
int64_t row_high[GGML_CUDA_MAX_DEVICES];
|
||||
|
||||
int used_devices = 0;
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
// by default, use all rows
|
||||
row_low[id] = 0;
|
||||
@ -6951,6 +7077,8 @@ static void ggml_cuda_op_mul_mat(
|
||||
continue;
|
||||
}
|
||||
|
||||
used_devices++;
|
||||
|
||||
const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device;
|
||||
const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device;
|
||||
|
||||
@ -6961,22 +7089,21 @@ static void ggml_cuda_op_mul_mat(
|
||||
src0_dd[id] = (char *) src0_extra->data_device[id];
|
||||
} else {
|
||||
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
|
||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
|
||||
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
|
||||
}
|
||||
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
src1_ddf[id] = (float *) src1_extra->data_device[id];
|
||||
} else {
|
||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
|
||||
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
|
||||
}
|
||||
|
||||
if (convert_src1_to_q8_1) {
|
||||
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
|
||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
|
||||
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
|
||||
|
||||
if (src1_on_device && src1_is_contiguous) {
|
||||
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
|
||||
// CUDA_CHECK(cudaGetLastError());
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
@ -6984,18 +7111,18 @@ static void ggml_cuda_op_mul_mat(
|
||||
dst_dd[id] = (float *) dst_extra->data_device[id];
|
||||
} else {
|
||||
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
|
||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
|
||||
dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
|
||||
}
|
||||
}
|
||||
|
||||
// if multiple devices are used they need to wait for the main device
|
||||
// here an event is recorded that signals that the main device has finished calculating the input data
|
||||
if (split && g_device_count > 1) {
|
||||
if (split && used_devices > 1) {
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0]));
|
||||
}
|
||||
|
||||
const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
|
||||
const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
|
||||
for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
|
||||
const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0;
|
||||
const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
|
||||
@ -7110,6 +7237,27 @@ static void ggml_cuda_op_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if ((!split && id != g_main_device) || row_low[id] == row_high[id]) {
|
||||
continue;
|
||||
}
|
||||
CUDA_CHECK(ggml_cuda_set_device(id));
|
||||
|
||||
// free buffers again when done
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
|
||||
}
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
|
||||
}
|
||||
if (src1_asq[id] > 0) {
|
||||
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
|
||||
}
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
|
||||
}
|
||||
}
|
||||
|
||||
// main device waits for all other devices to be finished
|
||||
if (split && g_device_count > 1) {
|
||||
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
|
||||
@ -7117,6 +7265,9 @@ static void ggml_cuda_op_mul_mat(
|
||||
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (row_low[id] == row_high[id]) {
|
||||
continue;
|
||||
}
|
||||
for (int64_t is = 0; is < is_max; ++is) {
|
||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
|
||||
}
|
||||
@ -7127,21 +7278,6 @@ static void ggml_cuda_op_mul_mat(
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (src0_as[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (src1_asf[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (src1_asq[id] > 0) {
|
||||
ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
if (dst_as[id] > 0) {
|
||||
ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -7168,6 +7304,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
|
||||
}
|
||||
|
||||
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
|
||||
}
|
||||
|
||||
static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
|
||||
}
|
||||
@ -7177,6 +7321,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
|
||||
}
|
||||
|
||||
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||
if (!g_cublas_loaded) return false;
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
@ -7328,11 +7474,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
|
||||
size_t src1_as = 0;
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
@ -7386,8 +7532,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
size_t ptrs_src_s = 0;
|
||||
size_t ptrs_dst_s = 0;
|
||||
|
||||
ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
|
||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
|
||||
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
|
||||
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
@ -7400,6 +7546,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
dst->nb[2], dst->nb[3],
|
||||
r2, r3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
@ -7411,30 +7558,29 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
if (ptrs_src_s != 0) {
|
||||
ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
|
||||
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
|
||||
}
|
||||
if (ptrs_dst_s != 0) {
|
||||
ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
|
||||
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
|
||||
}
|
||||
if (dst_as != 0) {
|
||||
ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool all_on_device =
|
||||
(src0->backend == GGML_BACKEND_GPU) &&
|
||||
(src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
|
||||
(src1->backend == GGML_BACKEND_GPU) &&
|
||||
( dst->backend == GGML_BACKEND_GPU);
|
||||
|
||||
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
|
||||
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
for (int64_t id = 0; id < g_device_count; ++id) {
|
||||
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
|
||||
@ -7456,13 +7602,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
@ -7549,6 +7695,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
|
||||
ne10, ne11, nb10, nb11, nb12, main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@ -7580,6 +7729,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
|
||||
}
|
||||
|
||||
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
|
||||
}
|
||||
|
||||
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
(void) src0;
|
||||
(void) src1;
|
||||
@ -7691,11 +7844,11 @@ static size_t g_temp_tensor_extra_index = 0;
|
||||
|
||||
static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (g_temp_tensor_extras == nullptr) {
|
||||
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
|
||||
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
|
||||
}
|
||||
|
||||
size_t alloc_index = g_temp_tensor_extra_index;
|
||||
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
|
||||
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
|
||||
ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
@ -7862,6 +8015,8 @@ void ggml_cuda_free_scratch() {
|
||||
}
|
||||
|
||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
if (!g_cublas_loaded) return false;
|
||||
|
||||
ggml_cuda_func_t func;
|
||||
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
|
||||
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|
||||
@ -7871,6 +8026,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT) {
|
||||
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
|
||||
#ifndef NDEBUG
|
||||
fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_REPEAT:
|
||||
func = ggml_cuda_repeat;
|
||||
@ -7895,6 +8059,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_UNARY_OP_SILU:
|
||||
func = ggml_cuda_silu;
|
||||
break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
func = ggml_cuda_relu;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
} break;
|
||||
@ -7913,6 +8080,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_SCALE:
|
||||
func = ggml_cuda_scale;
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
func = ggml_cuda_sqr;
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
@ -7943,6 +8113,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|
||||
case GGML_OP_ALIBI:
|
||||
func = ggml_cuda_alibi;
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
func = ggml_cuda_im2col;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -8002,11 +8175,11 @@ struct ggml_backend_buffer_context_cuda {
|
||||
|
||||
ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
|
||||
if (temp_tensor_extras == nullptr) {
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_DEFAULT_GRAPH_SIZE];
|
||||
temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES];
|
||||
}
|
||||
|
||||
size_t alloc_index = temp_tensor_extra_index;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_DEFAULT_GRAPH_SIZE;
|
||||
temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_CUDA_MAX_NODES;
|
||||
ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index];
|
||||
memset(extra, 0, sizeof(*extra));
|
||||
|
||||
|
@ -17,7 +17,12 @@ extern "C" {
|
||||
|
||||
#define GGML_CUDA_MAX_DEVICES 16
|
||||
|
||||
// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
|
||||
GGML_API void ggml_init_cublas(void);
|
||||
|
||||
// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
|
||||
GGML_API bool ggml_cublas_loaded(void);
|
||||
|
||||
GGML_API void * ggml_cuda_host_malloc(size_t size);
|
||||
GGML_API void ggml_cuda_host_free(void * ptr);
|
||||
|
||||
|
@ -26,7 +26,7 @@
|
||||
#include <stdbool.h>
|
||||
|
||||
// max memory buffers that can be mapped to the device
|
||||
#define GGML_METAL_MAX_BUFFERS 16
|
||||
#define GGML_METAL_MAX_BUFFERS 64
|
||||
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
|
||||
|
||||
struct ggml_tensor;
|
||||
|
89
ggml-metal.m
@ -86,6 +86,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -114,6 +115,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
@ -126,7 +128,7 @@ struct ggml_metal_context {
|
||||
// MSL code
|
||||
// TODO: move the contents here when ready
|
||||
// for now it is easier to work in a separate file
|
||||
static NSString * const msl_library_source = @"see metal.metal";
|
||||
//static NSString * const msl_library_source = @"see metal.metal";
|
||||
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
@ -142,7 +144,8 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
|
||||
ggml_metal_log_user_data = user_data;
|
||||
}
|
||||
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
||||
if (ggml_metal_log_callback != NULL) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
@ -287,6 +290,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -317,6 +321,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
@ -335,7 +340,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
||||
if ([ctx->device supportsFamily:i]) {
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -386,6 +391,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -416,6 +422,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
@ -473,6 +480,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
||||
|
||||
const int64_t tsize = ggml_nbytes(t);
|
||||
|
||||
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
|
||||
ctx = t->buffer->backend->context;
|
||||
}
|
||||
|
||||
// find the view that contains the tensor fully
|
||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||
@ -573,7 +584,7 @@ bool ggml_metal_add_buffer(
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
} else {
|
||||
GGML_METAL_LOG_INFO("\n");
|
||||
}
|
||||
@ -1139,6 +1150,7 @@ void ggml_metal_graph_compute(
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
||||
nrows = 4;
|
||||
} break;
|
||||
@ -1146,13 +1158,18 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
if (src1t == GGML_TYPE_F32) {
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||
nrows = 4;
|
||||
}
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
|
||||
nrows = 4;
|
||||
}
|
||||
} break;
|
||||
@ -1464,6 +1481,58 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||
|
||||
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
||||
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
||||
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
||||
const int32_t IW = src1->ne[0];
|
||||
|
||||
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
||||
const int32_t KW = src0->ne[0];
|
||||
|
||||
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const int32_t OW = dst->ne[1];
|
||||
|
||||
const int32_t CHW = IC * KH * KW;
|
||||
|
||||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
||||
default: GGML_ASSERT(false);
|
||||
};
|
||||
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
108
ggml-metal.metal
@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F32_F32;
|
||||
@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
#define N_F16_F16 4
|
||||
|
||||
kernel void kernel_mul_mv_f16_f16(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F16_F16;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (half) x[i] * (half) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *)x;
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
device const half4 * y4 = (device const half4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
|
||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||
|
||||
kernel void kernel_im2col_f16(
|
||||
device const float * x,
|
||||
device half * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
|
||||
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
|
||||
|
||||
const int32_t offset_dst =
|
||||
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
||||
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
19
ggml.h
@ -403,13 +403,8 @@ extern "C" {
|
||||
GGML_OP_ROPE_BACK,
|
||||
GGML_OP_ALIBI,
|
||||
GGML_OP_CLAMP,
|
||||
GGML_OP_CONV_1D,
|
||||
GGML_OP_CONV_1D_STAGE_0, // internal
|
||||
GGML_OP_CONV_1D_STAGE_1, // internal
|
||||
GGML_OP_CONV_TRANSPOSE_1D,
|
||||
GGML_OP_CONV_2D,
|
||||
GGML_OP_CONV_2D_STAGE_0, // internal
|
||||
GGML_OP_CONV_2D_STAGE_1, // internal
|
||||
GGML_OP_IM2COL,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
@ -1398,6 +1393,18 @@ extern "C" {
|
||||
float min,
|
||||
float max);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_im2col(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1,
|
||||
bool is_2D);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
57
grammars/assistant.gbnf
Normal file
@ -0,0 +1,57 @@
|
||||
# - "turn on lights."
|
||||
# - "set thermostat to 22."
|
||||
# - "increase TV by 10."
|
||||
# - "decrease oven by 50."
|
||||
# - "play music."
|
||||
# - "stop podcast."
|
||||
# - "schedule cleaning at 3pm."
|
||||
# - "cancel cleaning."
|
||||
# - "remind me to buy milk at 5pm."
|
||||
# - "show me security system."
|
||||
# - "hide washing machine."
|
||||
# - "what is the lights status?"
|
||||
# - "what is the current thermostat value?"
|
||||
# - "what is the security system status?"
|
||||
# - "what is the door lock status?"
|
||||
# - "what is the camera battery level?"
|
||||
# - "what is the weather like today?"
|
||||
# - "what is the forecast for tomorrow?"
|
||||
# - "what is the time?"
|
||||
# - "what is my schedule for today?"
|
||||
# - "what tasks do I have?"
|
||||
# - "what reminders do I have?"
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
|
||||
#
|
||||
|
||||
root ::= init " " (command | question) "."
|
||||
prompt ::= init
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " Ok Whisper, start listening for commands."
|
||||
|
||||
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
||||
"Increase " device " by " value | "Decrease " device " by " value |
|
||||
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
|
||||
"Remind me to " task " at " time | "Show me " device | "Hide " device
|
||||
|
||||
question ::= "What is the " device " status?" | "What is the current " device " value?" |
|
||||
"What is the " device " temperature?" | "What is the " device " humidity?" |
|
||||
"What is the " device " power consumption?" | "What is the " device " battery level?" |
|
||||
"What is the weather like today?" | "What is the forecast for tomorrow?" |
|
||||
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
|
||||
"What reminders do I have?"
|
||||
|
||||
device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
|
||||
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
|
||||
"vacuum cleaner"
|
||||
|
||||
value ::= [0-9]+
|
||||
|
||||
media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
|
||||
|
||||
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
||||
|
||||
time ::= [0-9] [0-9]? ("am" | "pm")?
|
29
grammars/chess.gbnf
Normal file
@ -0,0 +1,29 @@
|
||||
# - bishop to c3
|
||||
# - rook to d4
|
||||
# - knight to e5
|
||||
# - d4 d5 knight to c3
|
||||
# - c3 queen to d4 king b1
|
||||
# - pawn to a1 bishop to b2 knight to c3
|
||||
#
|
||||
# The prompt (--prompt) is the initial phrase that the user has to say.
|
||||
# This is used to prime Whisper with how the user is expected to speak.
|
||||
#
|
||||
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
|
||||
# Longer context is better, but it slightly increases the processing time.
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
|
||||
#
|
||||
|
||||
root ::= init move move? move? "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " rook to b4, f3"
|
||||
|
||||
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
||||
|
||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
||||
king ::= "king"
|
||||
pawn ::= "pawn"
|
16
grammars/colors.gbnf
Normal file
@ -0,0 +1,16 @@
|
||||
# - red
|
||||
# - green
|
||||
# - blue
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
|
||||
#
|
||||
|
||||
root ::= init color "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " red, green, blue"
|
||||
|
||||
color ::= ", " ("red" | "green" | "blue")
|
@ -252,7 +252,7 @@ class WhisperANE(Whisper):
|
||||
def convert_encoder(hparams, model, quantize=False):
|
||||
model.eval()
|
||||
|
||||
input_shape = (1, 80, 3000)
|
||||
input_shape = (1, hparams.n_mels, 3000)
|
||||
input_data = torch.randn(input_shape)
|
||||
traced_model = torch.jit.trace(model, input_data)
|
||||
|
||||
@ -302,7 +302,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
||||
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
||||
raise ValueError("Invalid model name")
|
||||
|
||||
whisper = load_model(args.model).cpu()
|
||||
|
@ -9,7 +9,7 @@ import shutil
|
||||
def convert_encoder(hparams, encoder, mname):
|
||||
encoder.eval()
|
||||
|
||||
mel = torch.zeros((1, 80, 3000))
|
||||
mel = torch.zeros((1, hparams.n_mels, 3000))
|
||||
|
||||
onnx_folder=os.path.join(os.path.dirname(__file__),"onnx_encoder")
|
||||
|
||||
|
1
spm-headers/ggml.h
Symbolic link
@ -0,0 +1 @@
|
||||
../ggml.h
|
1
spm-headers/whisper.h
Symbolic link
@ -0,0 +1 @@
|
||||
../whisper.h
|
2450
whisper.cpp
58
whisper.h
@ -1,6 +1,8 @@
|
||||
#ifndef WHISPER_H
|
||||
#define WHISPER_H
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
@ -76,7 +78,9 @@ extern "C" {
|
||||
struct whisper_state;
|
||||
struct whisper_full_params;
|
||||
|
||||
typedef int whisper_token;
|
||||
typedef int32_t whisper_pos;
|
||||
typedef int32_t whisper_token;
|
||||
typedef int32_t whisper_seq_id;
|
||||
|
||||
struct whisper_context_params {
|
||||
bool use_gpu;
|
||||
@ -107,18 +111,49 @@ extern "C" {
|
||||
void (*close)(void * ctx);
|
||||
} whisper_model_loader;
|
||||
|
||||
// grammar element type
|
||||
enum whisper_gretype {
|
||||
// end of rule definition
|
||||
WHISPER_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
WHISPER_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
WHISPER_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
WHISPER_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
WHISPER_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or
|
||||
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
WHISPER_GRETYPE_CHAR_ALT = 6,
|
||||
};
|
||||
|
||||
typedef struct whisper_grammar_element {
|
||||
enum whisper_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} whisper_grammar_element;
|
||||
|
||||
// Various functions for loading a ggml whisper model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params);
|
||||
|
||||
// These are the same as the above, but the internal state of the context is not allocated automatically
|
||||
// It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
|
||||
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
|
||||
WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params);
|
||||
|
||||
WHISPER_DEPRECATED(
|
||||
WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
|
||||
@ -400,6 +435,7 @@ extern "C" {
|
||||
|
||||
bool translate;
|
||||
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||
bool no_timestamps; // do not generate timestamps
|
||||
bool single_segment; // force single segment output (useful for streaming)
|
||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||
bool print_progress; // print progress information
|
||||
@ -477,6 +513,11 @@ extern "C" {
|
||||
// called by each decoder to filter obtained logits
|
||||
whisper_logits_filter_callback logits_filter_callback;
|
||||
void * logits_filter_callback_user_data;
|
||||
|
||||
const whisper_grammar_element ** grammar_rules;
|
||||
size_t n_grammar_rules;
|
||||
size_t i_start_rule;
|
||||
float grammar_penalty;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
|
||||
@ -570,8 +611,7 @@ extern "C" {
|
||||
|
||||
// Control logging output; default behavior is to print to stderr
|
||||
|
||||
typedef void (*whisper_log_callback)(const char * line);
|
||||
WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
|
||||
WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|