Compare commits
2 Commits
cuda-cubla
...
ggml-backe
Author | SHA1 | Date | |
---|---|---|---|
bf4110dbcf | |||
005b8ccbf0 |
44
.github/workflows/build.yml
vendored
@ -25,7 +25,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential libsdl2-dev
|
||||
make
|
||||
@ -87,7 +86,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
@ -115,10 +113,8 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y clang
|
||||
apt install -y clang build-essential cmake libsdl2-dev
|
||||
apt install -y build-essential cmake libsdl2-dev
|
||||
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
|
||||
make
|
||||
ctest -L gh --output-on-failure'
|
||||
@ -144,7 +140,6 @@ jobs:
|
||||
docker run --platform ${{ matrix.arch }} --rm \
|
||||
-v ${{ github.workspace }}:/workspace \
|
||||
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
|
||||
set -e
|
||||
apt update
|
||||
apt install -y build-essential cmake
|
||||
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
|
||||
@ -222,10 +217,10 @@ jobs:
|
||||
sdl2: [ON]
|
||||
include:
|
||||
- arch: Win32
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
|
||||
s2arc: x86
|
||||
- arch: x64
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
|
||||
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
|
||||
s2arc: x64
|
||||
- sdl2: ON
|
||||
s2ver: 2.26.0
|
||||
@ -325,13 +320,6 @@ jobs:
|
||||
cd ./build
|
||||
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
|
||||
|
||||
- name: Copy CUDA DLLs
|
||||
run: >
|
||||
Copy-Item -PassThru
|
||||
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
|
||||
-Include cudart64_*,cublas64_*,cublasLt64_*
|
||||
-Destination build/bin/${{ matrix.build }}
|
||||
|
||||
- name: Copy SDL2.dll
|
||||
if: matrix.sdl2 == 'ON'
|
||||
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
|
||||
@ -408,32 +396,6 @@ jobs:
|
||||
cd examples/whisper.android
|
||||
./gradlew assembleRelease --no-daemon
|
||||
|
||||
android_java:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: set up JDK 11
|
||||
uses: actions/setup-java@v3
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: gradle
|
||||
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v2
|
||||
with:
|
||||
api-level: 30
|
||||
build-tools-version: 30.0.3
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd examples/whisper.android.java
|
||||
chmod +x ./gradlew
|
||||
./gradlew assembleRelease
|
||||
|
||||
java:
|
||||
needs: [ 'windows' ]
|
||||
runs-on: windows-latest
|
||||
|
6
.gitignore
vendored
@ -8,7 +8,6 @@
|
||||
.DS_Store
|
||||
|
||||
build/
|
||||
build-coreml/
|
||||
build-em/
|
||||
build-debug/
|
||||
build-release/
|
||||
@ -31,7 +30,6 @@ build-sanitize-thread/
|
||||
/talk-llama
|
||||
/bench
|
||||
/quantize
|
||||
/server
|
||||
/lsp
|
||||
|
||||
arm_neon.h
|
||||
@ -55,7 +53,3 @@ bindings/java/.idea/
|
||||
.idea/
|
||||
|
||||
benchmark_results.csv
|
||||
cmake-build-debug/
|
||||
.cxx/
|
||||
.gradle/
|
||||
local.properties
|
@ -1,6 +1,6 @@
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(whisper.cpp VERSION 1.5.1)
|
||||
project(whisper.cpp VERSION 1.4.3)
|
||||
|
||||
# Add path to modules
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
53
Makefile
@ -1,4 +1,4 @@
|
||||
default: main bench quantize server
|
||||
default: main bench quantize
|
||||
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
@ -307,7 +307,7 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
|
||||
ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
|
||||
WHISPER_OBJ += ggml-alloc.o ggml-backend.o ggml-quants.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
@ -331,14 +331,14 @@ ggml-metal.o: ggml-metal.m ggml-metal.h
|
||||
WHISPER_OBJ += ggml-metal.o
|
||||
endif
|
||||
|
||||
libwhisper.a: $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a $(WHISPER_OBJ)
|
||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||
|
||||
libwhisper.so: $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
|
||||
libwhisper.so: ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
|
||||
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
|
||||
|
||||
#
|
||||
# Examples
|
||||
@ -349,33 +349,30 @@ CC_SDL=`sdl2-config --cflags --libs`
|
||||
SRC_COMMON = examples/common.cpp examples/common-ggml.cpp
|
||||
SRC_COMMON_SDL = examples/common-sdl.cpp
|
||||
|
||||
main: examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
|
||||
./main -h
|
||||
|
||||
bench: examples/bench/bench.cpp $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
|
||||
|
||||
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o quantize $(LDFLAGS)
|
||||
|
||||
server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o server $(LDFLAGS)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk-llama $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
#
|
||||
# Audio samples
|
||||
@ -421,9 +418,9 @@ samples:
|
||||
.PHONY: medium
|
||||
.PHONY: large-v1
|
||||
.PHONY: large-v2
|
||||
.PHONY: large-v3
|
||||
.PHONY: large
|
||||
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3: main
|
||||
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large: main
|
||||
bash ./models/download-ggml-model.sh $@
|
||||
@echo ""
|
||||
@echo "==============================================="
|
||||
|
35
README.md
@ -6,7 +6,7 @@
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://www.npmjs.com/package/whisper.cpp/)
|
||||
|
||||
Stable: [v1.5.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
Beta: [v1.4.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.3) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
|
||||
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
@ -16,10 +16,12 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Support for CPU-only inference
|
||||
- [Efficient GPU support for NVIDIA](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
|
||||
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
|
||||
|
||||
@ -34,8 +36,10 @@ Supported platforms:
|
||||
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
|
||||
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
|
||||
|
||||
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
|
||||
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
|
||||
The entire implementation of the model is contained in 2 source files:
|
||||
|
||||
- Tensor operations: [ggml.h](ggml.h) / [ggml.c](ggml.c)
|
||||
- Transformer inference: [whisper.h](whisper.h) / [whisper.cpp](whisper.cpp)
|
||||
|
||||
Having such a lightweight implementation of the model allows to easily integrate it in different platforms and applications.
|
||||
As an example, here is a video of running the model on an iPhone 13 device - fully offline, on-device: [whisper.objc](examples/whisper.objc)
|
||||
@ -231,18 +235,18 @@ make medium.en
|
||||
make medium
|
||||
make large-v1
|
||||
make large-v2
|
||||
make large-v3
|
||||
make large
|
||||
```
|
||||
|
||||
## Memory usage
|
||||
|
||||
| Model | Disk | Mem |
|
||||
| --- | --- | --- |
|
||||
| tiny | 75 MiB | ~273 MB |
|
||||
| base | 142 MiB | ~388 MB |
|
||||
| small | 466 MiB | ~852 MB |
|
||||
| medium | 1.5 GiB | ~2.1 GB |
|
||||
| large | 2.9 GiB | ~3.9 GB |
|
||||
| Model | Disk | Mem | SHA |
|
||||
| --- | --- | --- | --- |
|
||||
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
|
||||
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
|
||||
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
|
||||
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
|
||||
| large | 2.9 GB | ~3.3 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
|
||||
|
||||
## Quantization
|
||||
|
||||
@ -396,12 +400,12 @@ This can result in significant speedup in encoder performance. Here are the inst
|
||||
|
||||
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
|
||||
cached for the next run.
|
||||
|
||||
|
||||
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
|
||||
|
||||
## NVIDIA GPU support
|
||||
## NVIDIA GPU support via cuBLAS
|
||||
|
||||
With NVIDIA cards the processing of the models is done efficiently on the GPU via cuBLAS and custom CUDA kernels.
|
||||
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
|
||||
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
|
||||
|
||||
Now build `whisper.cpp` with cuBLAS support:
|
||||
@ -777,7 +781,6 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
|
||||
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
|
||||
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
|
||||
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
|
||||
| [server](examples/server) | | HTTP transcription server with OAI-like API |
|
||||
|
||||
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)
|
||||
|
||||
|
@ -1,26 +1,9 @@
|
||||
ifndef UNAME_S
|
||||
UNAME_S := $(shell uname -s)
|
||||
endif
|
||||
|
||||
ifndef UNAME_P
|
||||
UNAME_P := $(shell uname -p)
|
||||
endif
|
||||
|
||||
ifndef UNAME_M
|
||||
UNAME_M := $(shell uname -m)
|
||||
endif
|
||||
|
||||
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
|
||||
BUILD_DIR := build
|
||||
MODELS_DIR := models
|
||||
EXAMPLES_DIR := $(wildcard examples/*)
|
||||
INCLUDE_PATH := $(abspath ../..)
|
||||
LIBRARY_PATH := $(abspath ../..)
|
||||
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
|
||||
endif
|
||||
|
||||
all: clean whisper examples
|
||||
|
||||
whisper: mkdir
|
||||
@ -28,13 +11,8 @@ whisper: mkdir
|
||||
@${MAKE} -C ../.. libwhisper.a
|
||||
|
||||
test: model-small whisper modtidy
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
|
||||
endif
|
||||
|
||||
examples: $(EXAMPLES_DIR)
|
||||
|
||||
@ -43,11 +21,7 @@ model-small: mkdir examples/go-model-download
|
||||
|
||||
$(EXAMPLES_DIR): mkdir whisper modtidy
|
||||
@echo Build example $(notdir $@)
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
else
|
||||
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
|
||||
endif
|
||||
|
||||
mkdir:
|
||||
@echo Mkdir ${BUILD_DIR}
|
||||
|
@ -24,7 +24,7 @@ const (
|
||||
|
||||
var (
|
||||
// The models which will be downloaded, if no model is specified as an argument
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3"}
|
||||
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large"}
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -9,7 +9,6 @@ archivesBaseName = 'whispercpp'
|
||||
group = 'io.github.ggerganov'
|
||||
version = '1.4.0'
|
||||
|
||||
|
||||
sourceCompatibility = 1.8
|
||||
targetCompatibility = 1.8
|
||||
|
||||
|
@ -2,7 +2,6 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import com.sun.jna.Native;
|
||||
import com.sun.jna.Pointer;
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperContextParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -10,8 +9,6 @@ import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
|
||||
@ -163,28 +160,6 @@ public class WhisperCpp implements AutoCloseable {
|
||||
|
||||
return str.toString().trim();
|
||||
}
|
||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||
if (ctx == null) {
|
||||
throw new IllegalStateException("Model not initialised");
|
||||
}
|
||||
|
||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
||||
throw new IOException("Failed to process audio");
|
||||
}
|
||||
|
||||
int nSegments = lib.whisper_full_n_segments(ctx);
|
||||
List<WhisperSegment> segments= new ArrayList<>(nSegments);
|
||||
|
||||
|
||||
for (int i = 0; i < nSegments; i++) {
|
||||
long t0 = lib.whisper_full_get_segment_t0(ctx, i);
|
||||
String text = lib.whisper_full_get_segment_text(ctx, i);
|
||||
long t1 = lib.whisper_full_get_segment_t1(ctx, i);
|
||||
segments.add(new WhisperSegment(t0,t1,text));
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
// public int getTextSegmentCount(Pointer ctx) {
|
||||
// return lib.whisper_full_n_segments(ctx);
|
||||
|
@ -1,47 +0,0 @@
|
||||
package io.github.ggerganov.whispercpp.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "[" + start + " --> " + end + "]:" + sentence;
|
||||
}
|
||||
}
|
@ -58,9 +58,6 @@ public class WhisperFullParams extends Structure {
|
||||
no_context = enable ? CBool.FALSE : CBool.TRUE;
|
||||
}
|
||||
|
||||
/** Generate timestamps or not? */
|
||||
public CBool no_timestamps;
|
||||
|
||||
/** Flag to force single segment output (useful for streaming). (default = false) */
|
||||
public CBool single_segment;
|
||||
|
||||
@ -307,16 +304,10 @@ public class WhisperFullParams extends Structure {
|
||||
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
|
||||
}
|
||||
|
||||
/** Grammar stuff */
|
||||
public Pointer grammar_rules;
|
||||
public long n_grammar_rules;
|
||||
public long i_start_rule;
|
||||
public float grammar_penalty;
|
||||
|
||||
@Override
|
||||
protected List<String> getFieldOrder() {
|
||||
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
|
||||
"no_context", "single_segment", "no_timestamps",
|
||||
"no_context", "single_segment",
|
||||
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
|
||||
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
|
||||
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
|
||||
@ -325,7 +316,6 @@ public class WhisperFullParams extends Structure {
|
||||
"new_segment_callback", "new_segment_callback_user_data",
|
||||
"progress_callback", "progress_callback_user_data",
|
||||
"encoder_begin_callback", "encoder_begin_callback_user_data",
|
||||
"logits_filter_callback", "logits_filter_callback_user_data",
|
||||
"grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
|
||||
"logits_filter_callback", "logits_filter_callback_user_data");
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package io.github.ggerganov.whispercpp;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import io.github.ggerganov.whispercpp.bean.WhisperSegment;
|
||||
import io.github.ggerganov.whispercpp.params.CBool;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
|
||||
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
|
||||
@ -12,7 +11,6 @@ import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.util.List;
|
||||
|
||||
class WhisperCppTest {
|
||||
private static WhisperCpp whisper = new WhisperCpp();
|
||||
@ -22,12 +20,11 @@ class WhisperCppTest {
|
||||
static void init() throws FileNotFoundException {
|
||||
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
|
||||
// or you can provide the absolute path to the model file.
|
||||
//String modelName = "../../models/ggml-tiny.bin";
|
||||
String modelName = "../../models/ggml-tiny.en.bin";
|
||||
try {
|
||||
whisper.initContext(modelName);
|
||||
//whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
//whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
modelInitialised = true;
|
||||
} catch (FileNotFoundException ex) {
|
||||
System.out.println("Model " + modelName + " not found");
|
||||
@ -45,7 +42,7 @@ class WhisperCppTest {
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertFalse(params.translate);
|
||||
assertEquals(0.01f, params.thold_pt);
|
||||
assertEquals(5, params.beam_search.beam_size);
|
||||
assertEquals(2, params.beam_search.beam_size);
|
||||
assertEquals(-1.0f, params.beam_search.patience);
|
||||
}
|
||||
|
||||
@ -58,7 +55,7 @@ class WhisperCppTest {
|
||||
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
|
||||
assertNotEquals(0, params.n_threads);
|
||||
assertEquals(16384, params.n_max_text_ctx);
|
||||
assertEquals(5, params.greedy.best_of);
|
||||
assertEquals(2, params.greedy.best_of);
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -75,11 +72,11 @@ class WhisperCppTest {
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
// params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
|
||||
try {
|
||||
@ -102,43 +99,4 @@ class WhisperCppTest {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFullTranscribeWithTime() throws Exception {
|
||||
if (!modelInitialised) {
|
||||
System.out.println("Model not initialised, skipping test");
|
||||
return;
|
||||
}
|
||||
|
||||
// Given
|
||||
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
|
||||
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
|
||||
|
||||
byte[] b = new byte[audioInputStream.available()];
|
||||
float[] floats = new float[b.length / 2];
|
||||
|
||||
//WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
|
||||
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
|
||||
params.print_progress = CBool.FALSE;
|
||||
//params.initial_prompt = "and so my fellow Americans um, like";
|
||||
|
||||
try {
|
||||
audioInputStream.read(b);
|
||||
|
||||
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
|
||||
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
|
||||
floats[j] = intSample / 32767.0f;
|
||||
}
|
||||
|
||||
List<WhisperSegment> segments = whisper.fullTranscribeWithTime(params, floats);
|
||||
assertTrue(segments.size() > 0, "The size of segments should be greater than 0");
|
||||
for (WhisperSegment segment : segments) {
|
||||
System.out.println(segment);
|
||||
}
|
||||
} finally {
|
||||
audioInputStream.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "whisper.cpp",
|
||||
"version": "1.5.1",
|
||||
"version": "1.4.3",
|
||||
"description": "Whisper speech recognition",
|
||||
"main": "whisper.js",
|
||||
"scripts": {
|
||||
|
@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
|
||||
|
||||
/**
|
||||
Make a prediction using the convenience interface
|
||||
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
|
||||
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
|
||||
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
||||
@return the prediction as whisper_encoder_implOutput
|
||||
*/
|
||||
|
@ -3,8 +3,6 @@
|
||||
// Code is derived from the work of Github user @wangchou
|
||||
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#if __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@ -16,8 +14,6 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
int64_t n_ctx,
|
||||
int64_t n_mel,
|
||||
float * mel,
|
||||
float * out);
|
||||
|
||||
|
@ -48,15 +48,13 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
|
||||
|
||||
void whisper_coreml_encode(
|
||||
const whisper_coreml_context * ctx,
|
||||
int64_t n_ctx,
|
||||
int64_t n_mel,
|
||||
float * mel,
|
||||
float * out) {
|
||||
MLMultiArray * inMultiArray = [
|
||||
[MLMultiArray alloc] initWithDataPointer: mel
|
||||
shape: @[@1, @(n_mel), @(n_ctx)]
|
||||
shape: @[@1, @80, @3000]
|
||||
dataType: MLMultiArrayDataTypeFloat32
|
||||
strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
|
||||
strides: @[@(240000), @(3000), @1]
|
||||
deallocator: nil
|
||||
error: nil
|
||||
];
|
||||
|
@ -23,7 +23,6 @@ add_library(${TARGET} STATIC
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
grammar-parser.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
@ -65,7 +64,6 @@ elseif(CMAKE_JS_VERSION)
|
||||
else()
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(stream)
|
||||
add_subdirectory(server)
|
||||
add_subdirectory(command)
|
||||
add_subdirectory(bench)
|
||||
add_subdirectory(quantize)
|
||||
|
@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
}
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
@ -104,30 +104,20 @@ int whisper_bench_full(const whisper_params & params) {
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// batched decoding
|
||||
for (int i = 0; i < 64; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
// prompt processing
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to decode: %d\n", ret);
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,6 @@
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@ -22,11 +21,6 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -36,12 +30,8 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@ -55,8 +45,6 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -87,9 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -124,30 +109,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
@ -155,41 +126,19 @@ std::string transcribe(
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||
if (grammar_parsed.symbol_ids.find(grammar_rule) == grammar_parsed.symbol_ids.end()) {
|
||||
fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, grammar_rule.c_str());
|
||||
} else {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
}
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
@ -198,17 +147,19 @@ std::string transcribe(
|
||||
|
||||
result += text;
|
||||
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
if(token.plog > 0.0f) exit(0);
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
@ -299,7 +250,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -467,9 +418,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
int n_tokens = 0;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
@ -507,7 +456,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -543,27 +492,18 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -596,11 +536,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min0);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
@ -621,30 +559,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= txt.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
@ -657,16 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
@ -734,36 +654,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
@ -9,11 +9,6 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
|
||||
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
|
||||
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
|
||||
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
|
||||
{"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
|
||||
{"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
|
||||
{"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
|
||||
{"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
|
||||
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
|
||||
};
|
||||
|
||||
void ggml_print_ftypes(FILE * fp) {
|
||||
@ -53,15 +48,15 @@ bool ggml_common_quantize_0(
|
||||
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
|
||||
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
|
||||
case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
|
||||
case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
|
||||
case GGML_FTYPE_UNKNOWN:
|
||||
case GGML_FTYPE_ALL_F32:
|
||||
case GGML_FTYPE_MOSTLY_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
|
||||
case GGML_FTYPE_MOSTLY_Q2_K:
|
||||
case GGML_FTYPE_MOSTLY_Q3_K:
|
||||
case GGML_FTYPE_MOSTLY_Q4_K:
|
||||
case GGML_FTYPE_MOSTLY_Q5_K:
|
||||
case GGML_FTYPE_MOSTLY_Q6_K:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
|
||||
return false;
|
||||
@ -172,17 +167,24 @@ bool ggml_common_quantize_0(
|
||||
|
||||
switch ((ggml_type) ttype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
|
||||
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
|
||||
} break;
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
@ -190,6 +192,11 @@ bool ggml_common_quantize_0(
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_Q8_K:
|
||||
case GGML_TYPE_COUNT:
|
||||
{
|
||||
|
@ -139,13 +139,10 @@ void audio_async::callback(uint8_t * stream, int len) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t n_samples = len / sizeof(float);
|
||||
const size_t n_samples = len / sizeof(float);
|
||||
|
||||
if (n_samples > m_audio.size()) {
|
||||
n_samples = m_audio.size();
|
||||
|
||||
stream += (len - (n_samples * sizeof(float)));
|
||||
}
|
||||
m_audio_new.resize(n_samples);
|
||||
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
|
||||
|
||||
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
|
||||
|
||||
@ -156,7 +153,7 @@ void audio_async::callback(uint8_t * stream, int len) {
|
||||
const size_t n0 = m_audio.size() - m_audio_pos;
|
||||
|
||||
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
|
||||
memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
|
||||
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
|
||||
|
||||
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
|
||||
m_audio_len = m_audio.size();
|
||||
|
@ -41,6 +41,7 @@ private:
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::vector<float> m_audio;
|
||||
std::vector<float> m_audio_new;
|
||||
size_t m_audio_pos = 0;
|
||||
size_t m_audio_len = 0;
|
||||
};
|
||||
|
@ -181,7 +181,7 @@ private:
|
||||
// It is assumed that PCM data is normalized to a range from -1 to 1
|
||||
bool write_audio(const float * data, size_t length) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
const int16_t intSample = data[i] * 32767;
|
||||
const auto intSample = static_cast<const int16_t>(data[i] * 32767);
|
||||
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
|
||||
dataSize += sizeof(int16_t);
|
||||
}
|
||||
|
@ -1,423 +0,0 @@
|
||||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from whisper.cpp
|
||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<whisper_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
||||
? WHISPER_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<whisper_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<whisper_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_char_element(whisper_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_CHAR: return true;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
whisper_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (auto kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by whisper.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
@ -48,7 +48,7 @@ if [ -n "$3" ]; then
|
||||
fi
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
@ -62,8 +62,8 @@ struct whisper_params {
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
@ -165,8 +165,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
|
@ -1,6 +0,0 @@
|
||||
set(TARGET server)
|
||||
add_executable(${TARGET} server.cpp httplib.h json.hpp)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
|
@ -1,68 +0,0 @@
|
||||
# whisper.cpp http server
|
||||
|
||||
Simple http server. WAV Files are passed to the inference model via http requests.
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
./server -h
|
||||
|
||||
usage: ./bin/server [options]
|
||||
|
||||
options:
|
||||
-h, --help [default] show this help message and exit
|
||||
-t N, --threads N [4 ] number of threads to use during computation
|
||||
-p N, --processors N [1 ] number of processors to use during computation
|
||||
-ot N, --offset-t N [0 ] time offset in milliseconds
|
||||
-on N, --offset-n N [0 ] segment index offset
|
||||
-d N, --duration N [0 ] duration of audio to process in milliseconds
|
||||
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
|
||||
-ml N, --max-len N [0 ] maximum segment length in characters
|
||||
-sow, --split-on-word [false ] split on word rather than on token
|
||||
-bo N, --best-of N [2 ] number of best candidates to keep
|
||||
-bs N, --beam-size N [-1 ] beam size for beam search
|
||||
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
|
||||
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
|
||||
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
|
||||
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
|
||||
-tr, --translate [false ] translate from source language to english
|
||||
-di, --diarize [false ] stereo audio diarization
|
||||
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
|
||||
-nf, --no-fallback [false ] do not use temperature fallback while decoding
|
||||
-ps, --print-special [false ] print special tokens
|
||||
-pc, --print-colors [false ] print colors
|
||||
-pr, --print-realtime [false ] print output in realtime
|
||||
-pp, --print-progress [false ] print progress
|
||||
-nt, --no-timestamps [false ] do not print timestamps
|
||||
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
|
||||
-dl, --detect-language [false ] exit after automatically detecting language
|
||||
--prompt PROMPT [ ] initial prompt
|
||||
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
|
||||
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
|
||||
--host HOST, [127.0.0.1] Hostname/ip-adress for the server
|
||||
--port PORT, [8080 ] Port number for the server
|
||||
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
|
||||
|
||||
## request examples
|
||||
|
||||
**/inference**
|
||||
```
|
||||
curl 127.0.0.1:8080/inference \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F file="@<file-path>" \
|
||||
-F temperature="0.2" \
|
||||
-F response-format="json"
|
||||
```
|
||||
|
||||
**/load**
|
||||
```
|
||||
curl 127.0.0.1:8080/load \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F model="<path-to-model-file>"
|
||||
```
|
24596
examples/server/json.hpp
@ -1,762 +0,0 @@
|
||||
#include "common.h"
|
||||
|
||||
#include "whisper.h"
|
||||
#include "httplib.h"
|
||||
#include "json.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
using namespace httplib;
|
||||
using json = nlohmann::json;
|
||||
|
||||
namespace {
|
||||
|
||||
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
||||
// Lowest is red, middle is yellow, highest is green.
|
||||
const std::vector<std::string> k_colors = {
|
||||
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
|
||||
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
|
||||
};
|
||||
|
||||
// output formats
|
||||
const std::string json_format = "json";
|
||||
const std::string text_format = "text";
|
||||
const std::string srt_format = "srt";
|
||||
const std::string vjson_format = "verbose_json";
|
||||
const std::string vtt_format = "vtt";
|
||||
|
||||
struct server_params
|
||||
{
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = "examples/server/public";
|
||||
|
||||
int32_t port = 8080;
|
||||
int32_t read_timeout = 600;
|
||||
int32_t write_timeout = 600;
|
||||
|
||||
bool ffmpeg_converter = false;
|
||||
};
|
||||
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
int32_t n_processors = 1;
|
||||
int32_t offset_t_ms = 0;
|
||||
int32_t offset_n = 0;
|
||||
int32_t duration_ms = 0;
|
||||
int32_t progress_step = 5;
|
||||
int32_t max_context = -1;
|
||||
int32_t max_len = 0;
|
||||
int32_t best_of = 2;
|
||||
int32_t beam_size = -1;
|
||||
|
||||
float word_thold = 0.01f;
|
||||
float entropy_thold = 2.40f;
|
||||
float logprob_thold = -1.00f;
|
||||
float userdef_temp = 0.20f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool debug_mode = false;
|
||||
bool translate = false;
|
||||
bool detect_language = false;
|
||||
bool diarize = false;
|
||||
bool tinydiarize = false;
|
||||
bool split_on_word = false;
|
||||
bool no_fallback = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_realtime = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool use_gpu = true;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
std::string response_format = json_format;
|
||||
|
||||
// [TDRZ] speaker turn string
|
||||
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
||||
|
||||
std::string openvino_encode_device = "CPU";
|
||||
};
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
int64_t msec = t * 10;
|
||||
int64_t hr = msec / (1000 * 60 * 60);
|
||||
msec = msec - hr * (1000 * 60 * 60);
|
||||
int64_t min = msec / (1000 * 60);
|
||||
msec = msec - min * (1000 * 60);
|
||||
int64_t sec = msec / 1000;
|
||||
msec = msec - sec * 1000;
|
||||
|
||||
char buf[32];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
|
||||
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
int timestamp_to_sample(int64_t t, int n_samples) {
|
||||
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
|
||||
}
|
||||
|
||||
bool is_file_exist(const char *fileName)
|
||||
{
|
||||
std::ifstream infile(fileName);
|
||||
return infile.good();
|
||||
}
|
||||
|
||||
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params,
|
||||
const server_params& sparams) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "usage: %s [options] \n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
|
||||
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
|
||||
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
|
||||
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
|
||||
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
|
||||
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
|
||||
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
|
||||
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
|
||||
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
|
||||
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
|
||||
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
|
||||
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
|
||||
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
||||
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
|
||||
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
|
||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
|
||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
|
||||
// server params
|
||||
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
|
||||
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
|
||||
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
|
||||
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
|
||||
for (int i = 1; i < argc; i++) {
|
||||
std::string arg = argv[i];
|
||||
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
||||
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
// server params
|
||||
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
||||
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
|
||||
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct whisper_print_user_data {
|
||||
const whisper_params * params;
|
||||
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
int progress_prev;
|
||||
};
|
||||
|
||||
void check_ffmpeg_availibility() {
|
||||
int result = system("ffmpeg -version");
|
||||
|
||||
if (result == 0) {
|
||||
std::cout << "ffmpeg is available." << std::endl;
|
||||
} else {
|
||||
// ffmpeg is not available
|
||||
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
|
||||
std::cout << "and that its executable is included in your system's PATH. ";
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
|
||||
std::ostringstream cmd_stream;
|
||||
std::string converted_filename_temp = temp_filename + "_temp.wav";
|
||||
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
|
||||
std::string cmd = cmd_stream.str();
|
||||
|
||||
int status = std::system(cmd.c_str());
|
||||
if (status != 0) {
|
||||
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the original file
|
||||
if (remove(temp_filename.c_str()) != 0) {
|
||||
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Rename the temporary file to match the original filename
|
||||
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
|
||||
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
|
||||
std::string speaker = "";
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "0";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "1";
|
||||
} else {
|
||||
speaker = "?";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
|
||||
if (!id_only) {
|
||||
speaker.insert(0, "(speaker ");
|
||||
speaker.append(")");
|
||||
}
|
||||
|
||||
return speaker;
|
||||
}
|
||||
|
||||
void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
|
||||
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
|
||||
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
|
||||
if (progress >= *progress_prev + progress_step) {
|
||||
*progress_prev += progress_step;
|
||||
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
|
||||
}
|
||||
}
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
|
||||
std::string speaker = "";
|
||||
|
||||
int64_t t0 = 0;
|
||||
int64_t t1 = 0;
|
||||
|
||||
// print the last n_new segments
|
||||
const int s0 = n_segments - n_new;
|
||||
|
||||
if (s0 == 0) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
for (int i = s0; i < n_segments; i++) {
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
}
|
||||
|
||||
if (!params.no_timestamps) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
if (params.print_special == false) {
|
||||
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||
if (id >= whisper_token_eot(ctx)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
|
||||
if (params.tinydiarize) {
|
||||
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
|
||||
printf("%s", params.tdrz_speaker_turn.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
}
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
|
||||
std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::stringstream result;
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
result << speaker << text << "\n";
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void get_req_parameters(const Request & req, whisper_params & params)
|
||||
{
|
||||
// user model configu.has_fileion
|
||||
if (req.has_file("offset-t"))
|
||||
{
|
||||
params.offset_t_ms = std::stoi(req.get_file_value("offset-t").content);
|
||||
}
|
||||
if (req.has_file("offset-n"))
|
||||
{
|
||||
params.offset_n = std::stoi(req.get_file_value("offset-n").content);
|
||||
}
|
||||
if (req.has_file("duration"))
|
||||
{
|
||||
params.duration_ms = std::stoi(req.get_file_value("duration").content);
|
||||
}
|
||||
if (req.has_file("max-context"))
|
||||
{
|
||||
params.max_context = std::stoi(req.get_file_value("max-context").content);
|
||||
}
|
||||
if (req.has_file("prompt"))
|
||||
{
|
||||
params.prompt = req.get_file_value("prompt").content;
|
||||
}
|
||||
if (req.has_file("response-format"))
|
||||
{
|
||||
params.response_format = req.get_file_value("response-format").content;
|
||||
}
|
||||
if (req.has_file("temperature"))
|
||||
{
|
||||
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
whisper_params params;
|
||||
server_params sparams;
|
||||
|
||||
std::mutex whisper_mutex;
|
||||
|
||||
if (whisper_params_parse(argc, argv, params, sparams) == false) {
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (params.diarize && params.tinydiarize) {
|
||||
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
|
||||
whisper_print_usage(argc, argv, params, sparams);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (sparams.ffmpeg_converter) {
|
||||
check_ffmpeg_availibility();
|
||||
}
|
||||
// whisper init
|
||||
struct whisper_context_params cparams;
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: failed to initialize whisper context\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||
|
||||
Server svr;
|
||||
|
||||
std::string const default_content = "<html>hello</html>";
|
||||
|
||||
// this is only called if no index.html is found in the public --path
|
||||
svr.Get("/", [&default_content](const Request &, Response &res){
|
||||
res.set_content(default_content, "text/html");
|
||||
return false;
|
||||
});
|
||||
|
||||
svr.Post("/inference", [&](const Request &req, Response &res){
|
||||
// acquire whisper model mutex lock
|
||||
whisper_mutex.lock();
|
||||
|
||||
// first check user requested fields of the request
|
||||
if (!req.has_file("file"))
|
||||
{
|
||||
fprintf(stderr, "error: no 'file' field in the request\n");
|
||||
const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
auto audio_file = req.get_file_value("file");
|
||||
|
||||
// check non-required fields
|
||||
get_req_parameters(req, params);
|
||||
|
||||
std::string filename{audio_file.filename};
|
||||
printf("Received request: %s\n", filename.c_str());
|
||||
|
||||
// audio arrays
|
||||
std::vector<float> pcmf32; // mono-channel F32 PCM
|
||||
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
|
||||
|
||||
// write to temporary file
|
||||
const std::string temp_filename = "whisper_server_temp_file.wav";
|
||||
std::ofstream temp_file{temp_filename, std::ios::binary};
|
||||
temp_file << audio_file.content;
|
||||
temp_file.close();
|
||||
|
||||
// if file is not wav, convert to wav
|
||||
|
||||
if (sparams.ffmpeg_converter) {
|
||||
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
|
||||
const bool is_converted = convert_to_wav(temp_filename, error_resp);
|
||||
if (!is_converted) {
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// read wav content into pcmf32
|
||||
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
|
||||
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
|
||||
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
std::remove(temp_filename.c_str());
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
// remove temp file
|
||||
std::remove(temp_filename.c_str());
|
||||
|
||||
printf("Successfully loaded %s\n", filename.c_str());
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
||||
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
|
||||
}
|
||||
|
||||
// print some info about the processing
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
if (!whisper_is_multilingual(ctx)) {
|
||||
if (params.language != "en" || params.translate) {
|
||||
params.language = "en";
|
||||
params.translate = false;
|
||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||
}
|
||||
}
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, filename.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// run the inference
|
||||
{
|
||||
printf("Running whisper.cpp inference on %s\n", filename.c_str());
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
|
||||
|
||||
wparams.print_realtime = false;
|
||||
wparams.print_progress = params.print_progress;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.detect_language = params.detect_language;
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
wparams.duration_ms = params.duration_ms;
|
||||
|
||||
wparams.thold_pt = params.word_thold;
|
||||
wparams.split_on_word = params.split_on_word;
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.debug_mode = params.debug_mode;
|
||||
|
||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
wparams.beam_search.beam_size = params.beam_size;
|
||||
|
||||
wparams.temperature_inc = params.userdef_temp;
|
||||
wparams.entropy_thold = params.entropy_thold;
|
||||
wparams.logprob_thold = params.logprob_thold;
|
||||
|
||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
|
||||
|
||||
// this callback is called on each new segment
|
||||
if (params.print_realtime) {
|
||||
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||
wparams.new_segment_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
if (wparams.print_progress) {
|
||||
wparams.progress_callback = whisper_print_progress_callback;
|
||||
wparams.progress_callback_user_data = &user_data;
|
||||
}
|
||||
|
||||
// examples for abort mechanism
|
||||
// in examples below, we do not abort the processing, but we could if the flag is set to true
|
||||
|
||||
// the callback is called before every encoder run - if it returns false, the processing is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return !is_aborted;
|
||||
};
|
||||
wparams.encoder_begin_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
// the callback is called before every computation - if it returns true, the computation is aborted
|
||||
{
|
||||
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
|
||||
|
||||
wparams.abort_callback = [](void * user_data) {
|
||||
bool is_aborted = *(bool*)user_data;
|
||||
return is_aborted;
|
||||
};
|
||||
wparams.abort_callback_user_data = &is_aborted;
|
||||
}
|
||||
|
||||
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
|
||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// return results to user
|
||||
if (params.response_format == text_format)
|
||||
{
|
||||
std::string results = output_str(ctx, params, pcmf32s);
|
||||
res.set_content(results.c_str(), "text/html");
|
||||
}
|
||||
// TODO add more output formats
|
||||
else
|
||||
{
|
||||
std::string results = output_str(ctx, params, pcmf32s);
|
||||
json jres = json{
|
||||
{"text", results}
|
||||
};
|
||||
res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
|
||||
"application/json");
|
||||
}
|
||||
|
||||
// return whisper model mutex lock
|
||||
whisper_mutex.unlock();
|
||||
});
|
||||
svr.Post("/load", [&](const Request &req, Response &res){
|
||||
whisper_mutex.lock();
|
||||
if (!req.has_file("model"))
|
||||
{
|
||||
fprintf(stderr, "error: no 'model' field in the request\n");
|
||||
const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
std::string model = req.get_file_value("model").content;
|
||||
if (!is_file_exist(model.c_str()))
|
||||
{
|
||||
fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
|
||||
const std::string error_resp = "{\"error\":\"model not found!\"}";
|
||||
res.set_content(error_resp, "application/json");
|
||||
whisper_mutex.unlock();
|
||||
return;
|
||||
}
|
||||
|
||||
// clean up
|
||||
whisper_free(ctx);
|
||||
|
||||
// whisper init
|
||||
ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
|
||||
|
||||
// TODO perhaps load prior model here instead of exit
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "error: model init failed, no model loaded must exit\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
|
||||
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
|
||||
|
||||
const std::string success = "Load was successful!";
|
||||
res.set_content(success, "application/text");
|
||||
|
||||
// check if the model is in the file system
|
||||
whisper_mutex.unlock();
|
||||
});
|
||||
|
||||
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
|
||||
const char fmt[] = "500 Internal Server Error\n%s";
|
||||
char buf[BUFSIZ];
|
||||
try {
|
||||
std::rethrow_exception(std::move(ep));
|
||||
} catch (std::exception &e) {
|
||||
snprintf(buf, sizeof(buf), fmt, e.what());
|
||||
} catch (...) {
|
||||
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
|
||||
}
|
||||
res.set_content(buf, "text/plain");
|
||||
res.status = 500;
|
||||
});
|
||||
|
||||
svr.set_error_handler([](const Request &, Response &res) {
|
||||
if (res.status == 400) {
|
||||
res.set_content("Invalid request", "text/plain");
|
||||
} else if (res.status != 500) {
|
||||
res.set_content("File Not Found", "text/plain");
|
||||
res.status = 404;
|
||||
}
|
||||
});
|
||||
|
||||
// set timeouts and change hostname and port
|
||||
svr.set_read_timeout(sparams.read_timeout);
|
||||
svr.set_write_timeout(sparams.write_timeout);
|
||||
|
||||
if (!svr.bind_to_port(sparams.hostname, sparams.port))
|
||||
{
|
||||
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
|
||||
sparams.hostname.c_str(), sparams.port);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Set the base directory for serving static files
|
||||
svr.set_base_dir(sparams.public_path);
|
||||
|
||||
// to make it ctrl+clickable:
|
||||
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
||||
|
||||
if (!svr.listen_after_bind())
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
@ -53,7 +53,6 @@ struct whisper_params {
|
||||
int32_t capture_id = -1;
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
int32_t n_gpu_layers = 999;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
@ -91,7 +90,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
|
||||
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
@ -136,7 +134,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
|
||||
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
|
||||
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
|
||||
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
|
||||
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
@ -251,7 +248,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
|
||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
@ -271,8 +268,6 @@ int main(int argc, char ** argv) {
|
||||
auto lmparams = llama_model_default_params();
|
||||
if (!params.use_gpu) {
|
||||
lmparams.n_gpu_layers = 0;
|
||||
} else {
|
||||
lmparams.n_gpu_layers = params.n_gpu_layers;
|
||||
}
|
||||
|
||||
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
|
||||
@ -686,8 +681,8 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "'", "'\"'\"'");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + text_to_speak + "'").c_str());
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||
}
|
||||
|
@ -121,13 +121,13 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
return false;
|
||||
}
|
||||
|
||||
char word[129];
|
||||
|
||||
std::string word;
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
uint32_t len;
|
||||
fin.read((char *) &len, sizeof(len));
|
||||
word[len] = '\0';
|
||||
fin.read((char *) word, len);
|
||||
|
||||
word.resize(len);
|
||||
fin.read((char *) word.data(), len);
|
||||
|
||||
vocab.token_to_id[word] = i;
|
||||
vocab.id_to_token[i] = word;
|
||||
|
@ -21,7 +21,7 @@ help()
|
||||
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
|
||||
echo "options:"
|
||||
echo "-s Step in seconds (default is $step)."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large-v3' (default is '$model')."
|
||||
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large' (default is '$model')."
|
||||
echo "-t Number of threads to use."
|
||||
echo "-h Print this help page."
|
||||
echo
|
||||
|
15
examples/whisper.android.java/.gitignore
vendored
@ -1,15 +0,0 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
@ -1,20 +0,0 @@
|
||||
A sample Android app using java code and [whisper.cpp](https://github.com/ggerganov/whisper.cpp/) to do voice-to-text transcriptions.
|
||||
|
||||
To use:
|
||||
|
||||
1. Select a model from the [whisper.cpp repository](https://github.com/ggerganov/whisper.cpp/tree/master/models).[^1]
|
||||
2. Copy the model to the "app/src/main/assets/models" folder.
|
||||
3. Select a sample audio file (for example, [jfk.wav](https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav)).
|
||||
4. Copy the sample to the "app/src/main/assets/samples" folder.
|
||||
5. Modify the modelFilePath in the WhisperService.java
|
||||
6. Modify the sampleFilePath in the WhisperService.java
|
||||
7. Select the "release" active build variant, and use Android Studio to run and deploy to your device.
|
||||
[^1]: I recommend the tiny or base models for running on an Android device.
|
||||
|
||||
PS:
|
||||
1. Do not move this android project folder individually to other folders, because this android project folder depends on the files of the whole project.
|
||||
2. The cpp code is compiled during the build process
|
||||
3. If you want to import a compiled cpp project in your Android project, please refer to the https://github.com/litongjava/whisper.cpp.android.java.demo
|
||||
|
||||

|
||||
|
Before Width: | Height: | Size: 67 KiB |
1
examples/whisper.android.java/app/.gitignore
vendored
@ -1 +0,0 @@
|
||||
/build
|
@ -1,58 +0,0 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
}
|
||||
|
||||
android {
|
||||
compileSdkVersion 30
|
||||
buildToolsVersion '30.0.3'
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.litongjava.whisper.android.java"
|
||||
minSdkVersion 21
|
||||
targetSdkVersion 30
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
cppFlags ""
|
||||
}
|
||||
}
|
||||
ndk {
|
||||
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
signingConfig signingConfigs.debug
|
||||
minifyEnabled true
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path "src/main/jni/whisper/CMakeLists.txt"
|
||||
}
|
||||
}
|
||||
ndkVersion "25.2.9519653"
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.appcompat:appcompat:1.1.0'
|
||||
implementation 'com.google.android.material:material:1.1.0'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
|
||||
testImplementation 'junit:junit:4.+'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||
|
||||
//litongjava
|
||||
implementation 'com.litongjava:android-view-inject:1.0'
|
||||
implementation 'com.litongjava:jfinal-aop:1.0.1'
|
||||
implementation 'com.litongjava:litongjava-android-utils:1.0.0'
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
@ -1,26 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public class ExampleInstrumentedTest {
|
||||
@Test
|
||||
public void useAppContext() {
|
||||
// Context of the app under test.
|
||||
Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
|
||||
assertEquals("com.litongjava.whisper.android.java", appContext.getPackageName());
|
||||
}
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.litongjava.whisper.android.java">
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:name=".app.App"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.Whisperandroidjava">
|
||||
<activity android:name=".MainActivity">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
@ -1,40 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<configuration debug="false" xmlns="http://ch.qos.logback/xml/ns/logback"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://ch.qos.logback/xml/ns/logback https://raw.githubusercontent.com/enricopulatzo/logback-XSD/master/src/main/xsd/logback.xsd
|
||||
http://ch.qos.logback/xml/ns/logback ">
|
||||
<!--Define the storage address of the log file Do not use relative paths in the LogBack configuration. -->
|
||||
<property name="LOG_HOME" value="logs" />
|
||||
<!--Formatted output: %d means the date, %-6level: log level from the left display 6 characters wide, %m: log message, %n is a newline character -->
|
||||
<property name="CONSOLE_LOG_PATTERN"
|
||||
value="%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-6level%logger{0}.%M:%L - %m%n" />
|
||||
|
||||
<!-- console output -->
|
||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
</appender>
|
||||
|
||||
<!-- Generate log files on a daily basis -->
|
||||
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
|
||||
<encoder class="ch.qos.logback.classic.encoder.PatternLayoutEncoder">
|
||||
<pattern>${CONSOLE_LOG_PATTERN}</pattern>
|
||||
</encoder>
|
||||
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
|
||||
<!--File name for log file output -->
|
||||
<fileNamePattern>${LOG_HOME}/project-name-%d{yyyy-MM-dd}.log</fileNamePattern>
|
||||
<!--Maximum size of log file -->
|
||||
<maxHistory>180</maxHistory>
|
||||
</rollingPolicy>
|
||||
<!--日志文件最大的大小 -->
|
||||
<triggeringPolicy class="ch.qos.logback.core.rolling.SizeBasedTriggeringPolicy">
|
||||
<maxFileSize>10MB</maxFileSize>
|
||||
</triggeringPolicy>
|
||||
</appender>
|
||||
<!-- Log output level and source-->
|
||||
<root level="info">
|
||||
<appender-ref ref="STDOUT" />
|
||||
<appender-ref ref="FILE" />
|
||||
</root>
|
||||
</configuration>
|
@ -1,107 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.Looper;
|
||||
import android.view.View;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewById;
|
||||
import com.litongjava.android.view.inject.annotation.FindViewByIdLayout;
|
||||
import com.litongjava.android.view.inject.annotation.OnClick;
|
||||
import com.litongjava.android.view.inject.utils.ViewInjectUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.jfinal.aop.AopManager;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
import com.litongjava.whisper.android.java.task.LoadModelTask;
|
||||
import com.litongjava.whisper.android.java.task.TranscriptionTask;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperLib;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
|
||||
@FindViewByIdLayout(R.layout.activity_main)
|
||||
public class MainActivity extends AppCompatActivity {
|
||||
|
||||
@FindViewById(R.id.sample_text)
|
||||
private TextView tv;
|
||||
|
||||
Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
private WhisperService whisperService = Aop.get(WhisperService.class);
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
//setContentView(R.layout.activity_main);
|
||||
ViewInjectUtils.injectActivity(this, this);
|
||||
initAopBean();
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
private void initAopBean() {
|
||||
Handler mainHandler = new Handler(Looper.getMainLooper());
|
||||
AopManager.me().addSingletonObject(mainHandler);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.loadModelBtn)
|
||||
public void loadModelBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
ThreadUtils.executeByIo(new LoadModelTask(tv));
|
||||
}
|
||||
|
||||
@OnClick(R.id.transcriptSampleBtn)
|
||||
public void transcriptSampleBtn_OnClick(View v) {
|
||||
Context context = getBaseContext();
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
String sampleFilePath = "samples/jfk.wav";
|
||||
File filesDir = context.getFilesDir();
|
||||
File sampleFile = AssetUtils.copyFileIfNotExists(context, filesDir, sampleFilePath);
|
||||
long end = System.currentTimeMillis();
|
||||
String msg = "copy file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ThreadUtils.executeByIo(new TranscriptionTask(tv, sampleFile));
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
tv.append(msg + "\n");
|
||||
log.info(msg);
|
||||
}
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@OnClick(R.id.systemInfoBtn)
|
||||
public void systemInfoBtn_OnClick(View v) {
|
||||
showSystemInfo();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void showSystemInfo() {
|
||||
String systemInfo = WhisperLib.getSystemInfo();
|
||||
tv.append(systemInfo + "\n");
|
||||
}
|
||||
|
||||
@OnClick(R.id.clearBtn)
|
||||
public void clearBtn_OnClick(View v) {
|
||||
tv.setText("");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
protected void onDestroy() {
|
||||
super.onDestroy();
|
||||
whisperService.release();
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.app;
|
||||
|
||||
import android.app.Application;
|
||||
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
|
||||
public class App extends Application {
|
||||
@Override
|
||||
public void onCreate() {
|
||||
super.onCreate();
|
||||
Utils.init(this);
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.bean;
|
||||
|
||||
/**
|
||||
* Created by litonglinux@qq.com on 10/21/2023_7:48 AM
|
||||
*/
|
||||
public class WhisperSegment {
|
||||
private long start, end;
|
||||
private String sentence;
|
||||
|
||||
public WhisperSegment() {
|
||||
}
|
||||
|
||||
public WhisperSegment(long start, long end, String sentence) {
|
||||
this.start = start;
|
||||
this.end = end;
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
public long getStart() {
|
||||
return start;
|
||||
}
|
||||
|
||||
public long getEnd() {
|
||||
return end;
|
||||
}
|
||||
|
||||
public String getSentence() {
|
||||
return sentence;
|
||||
}
|
||||
|
||||
public void setStart(long start) {
|
||||
this.start = start;
|
||||
}
|
||||
|
||||
public void setEnd(long end) {
|
||||
this.end = end;
|
||||
}
|
||||
|
||||
public void setSentence(String sentence) {
|
||||
this.sentence = sentence;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "["+start+" --> "+end+"]:"+sentence;
|
||||
}
|
||||
}
|
@ -1,101 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.services;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.android.utils.dialog.AlertDialogUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.single.LocalWhisper;
|
||||
import com.litongjava.whisper.android.java.utils.WaveEncoder;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
public class WhisperService {
|
||||
private Logger log = LoggerFactory.getLogger(this.getClass());
|
||||
|
||||
private final Object lock = new Object();
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void loadModel(TextView tv) {
|
||||
String modelFilePath = LocalWhisper.modelFilePath;
|
||||
String msg = "load model from :" + modelFilePath + "\n";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
long start = System.currentTimeMillis();
|
||||
LocalWhisper.INSTANCE.init();
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "model load successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
ToastUtils.showLong(msg);
|
||||
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void transcribeSample(TextView tv, File sampleFile) {
|
||||
String msg = "";
|
||||
msg = "transcribe file from :" + sampleFile.getAbsolutePath();
|
||||
outputMsg(tv, msg);
|
||||
|
||||
Long start = System.currentTimeMillis();
|
||||
float[] audioData = new float[0]; // 读取音频样本
|
||||
try {
|
||||
audioData = WaveEncoder.decodeWaveFile(sampleFile);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return;
|
||||
}
|
||||
long end = System.currentTimeMillis();
|
||||
msg = "decode wave file:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
start = System.currentTimeMillis();
|
||||
List<WhisperSegment> transcription = null;
|
||||
try {
|
||||
//transcription = LocalWhisper.INSTANCE.transcribeData(audioData);
|
||||
transcription = LocalWhisper.INSTANCE.transcribeDataWithTime(audioData);
|
||||
} catch (ExecutionException e) {
|
||||
e.printStackTrace();
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
end = System.currentTimeMillis();
|
||||
if(transcription!=null){
|
||||
ToastUtils.showLong(transcription.toString());
|
||||
msg = "Transcript successful:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
|
||||
outputMsg(tv, transcription.toString());
|
||||
|
||||
}else{
|
||||
msg = "Transcript failed:" + (end - start) + "ms";
|
||||
outputMsg(tv, msg);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void outputMsg(TextView tv, String msg) {
|
||||
log.info(msg);
|
||||
if(tv!=null){
|
||||
Aop.get(Handler.class).post(()->{ tv.append(msg + "\n");});
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() {
|
||||
//noting to do
|
||||
}
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.single;
|
||||
|
||||
import android.app.Application;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.blankj.utilcode.util.ToastUtils;
|
||||
import com.blankj.utilcode.util.Utils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
import com.litongjava.whisper.android.java.utils.AssetUtils;
|
||||
import com.whispercpp.java.whisper.WhisperContext;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public enum LocalWhisper {
|
||||
INSTANCE;
|
||||
|
||||
public static final String modelFilePath = "models/ggml-tiny.bin";
|
||||
private WhisperContext whisperContext;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
LocalWhisper() {
|
||||
Application context = Utils.getApp();
|
||||
File filesDir = context.getFilesDir();
|
||||
File modelFile = AssetUtils.copyFileIfNotExists(context, filesDir, modelFilePath);
|
||||
String realModelFilePath = modelFile.getAbsolutePath();
|
||||
whisperContext = WhisperContext.createContextFromFile(realModelFilePath);
|
||||
}
|
||||
|
||||
public synchronized String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeData(data);
|
||||
}
|
||||
}
|
||||
|
||||
private static void toastModelLoading() {
|
||||
Aop.get(Handler.class).post(()->{
|
||||
ToastUtils.showShort("please wait for model loading");
|
||||
});
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] audioData) throws ExecutionException, InterruptedException {
|
||||
if(whisperContext==null){
|
||||
toastModelLoading();
|
||||
return null;
|
||||
}else{
|
||||
return whisperContext.transcribeDataWithTime(audioData);
|
||||
}
|
||||
}
|
||||
|
||||
public void init() {
|
||||
//noting to do.but init
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.os.Handler;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class LoadModelTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
public LoadModelTask(TextView tv) {
|
||||
this.tv = tv;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).loadModel(tv);
|
||||
}else{
|
||||
Aop.get(Handler.class).post(()->{
|
||||
tv.append("not supported android devices");
|
||||
});
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.task;
|
||||
|
||||
import android.content.Context;
|
||||
import android.os.Build;
|
||||
import android.widget.TextView;
|
||||
|
||||
import com.blankj.utilcode.util.ThreadUtils;
|
||||
import com.litongjava.jfinal.aop.Aop;
|
||||
import com.litongjava.whisper.android.java.services.WhisperService;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
public class TranscriptionTask extends ThreadUtils.Task<Object> {
|
||||
private final TextView tv;
|
||||
private final File sampleFile;
|
||||
|
||||
public TranscriptionTask(TextView tv, File sampleFile) {
|
||||
this.tv = tv;
|
||||
this.sampleFile = sampleFile;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object doInBackground() {
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
|
||||
Aop.get(WhisperService.class).transcribeSample(tv, sampleFile);
|
||||
}else{
|
||||
tv.append("not supported android devices");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(Object result) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onCancel() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFail(Throwable t) {
|
||||
}
|
||||
}
|
@ -1,91 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
|
||||
public class AssetUtils {
|
||||
private static Logger log = LoggerFactory.getLogger(AssetUtils.class);
|
||||
|
||||
public static File copyFileIfNotExists(Context context, File distDir, String filename) {
|
||||
File dstFile = new File(distDir, filename);
|
||||
if (dstFile.exists()) {
|
||||
return dstFile;
|
||||
} else {
|
||||
File parentFile = dstFile.getParentFile();
|
||||
log.info("parentFile:{}", parentFile);
|
||||
if (!parentFile.exists()) {
|
||||
parentFile.mkdirs();
|
||||
}
|
||||
AssetUtils.copyFileFromAssets(context, filename, dstFile);
|
||||
}
|
||||
return dstFile;
|
||||
}
|
||||
|
||||
public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
|
||||
if (srcDir.isEmpty() || dstDir.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (!new File(dstDir).exists()) {
|
||||
new File(dstDir).mkdirs();
|
||||
}
|
||||
for (String fileName : appCtx.getAssets().list(srcDir)) {
|
||||
String srcSubPath = srcDir + File.separator + fileName;
|
||||
String dstSubPath = dstDir + File.separator + fileName;
|
||||
if (new File(srcSubPath).isDirectory()) {
|
||||
copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
} else {
|
||||
copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
|
||||
File dstFile = new File(dstPath);
|
||||
copyFileFromAssets(appCtx, srcPath, dstFile);
|
||||
}
|
||||
|
||||
public static void copyFileFromAssets(Context appCtx, String srcPath, File dstFile) {
|
||||
if (srcPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
InputStream is = null;
|
||||
OutputStream os = null;
|
||||
try {
|
||||
is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
|
||||
|
||||
os = new BufferedOutputStream(new FileOutputStream(dstFile));
|
||||
byte[] buffer = new byte[1024];
|
||||
int length = 0;
|
||||
while ((length = is.read(buffer)) != -1) {
|
||||
os.write(buffer, 0, length);
|
||||
}
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
try {
|
||||
os.close();
|
||||
is.close();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
package com.litongjava.whisper.android.java.utils;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.ShortBuffer;
|
||||
|
||||
public class WaveEncoder {
|
||||
|
||||
public static float[] decodeWaveFile(File file) throws IOException {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
try (FileInputStream fis = new FileInputStream(file)) {
|
||||
byte[] buffer = new byte[1024];
|
||||
int bytesRead;
|
||||
while ((bytesRead = fis.read(buffer)) != -1) {
|
||||
baos.write(buffer, 0, bytesRead);
|
||||
}
|
||||
}
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(baos.toByteArray());
|
||||
byteBuffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
int channel = byteBuffer.getShort(22);
|
||||
byteBuffer.position(44);
|
||||
|
||||
ShortBuffer shortBuffer = byteBuffer.asShortBuffer();
|
||||
short[] shortArray = new short[shortBuffer.limit()];
|
||||
shortBuffer.get(shortArray);
|
||||
|
||||
float[] output = new float[shortArray.length / channel];
|
||||
|
||||
for (int index = 0; index < output.length; index++) {
|
||||
if (channel == 1) {
|
||||
output[index] = Math.max(-1f, Math.min(1f, shortArray[index] / 32767.0f));
|
||||
} else {
|
||||
output[index] = Math.max(-1f, Math.min(1f, (shortArray[2 * index] + shortArray[2 * index + 1]) / 32767.0f / 2.0f));
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
public static void encodeWaveFile(File file, short[] data) throws IOException {
|
||||
try (FileOutputStream fos = new FileOutputStream(file)) {
|
||||
fos.write(headerBytes(data.length * 2));
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(data.length * 2);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
buffer.asShortBuffer().put(data);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
fos.write(bytes);
|
||||
}
|
||||
}
|
||||
|
||||
private static byte[] headerBytes(int totalLength) {
|
||||
if (totalLength < 44)
|
||||
throw new IllegalArgumentException("Total length must be at least 44 bytes");
|
||||
|
||||
ByteBuffer buffer = ByteBuffer.allocate(44);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
buffer.put((byte) 'R');
|
||||
buffer.put((byte) 'I');
|
||||
buffer.put((byte) 'F');
|
||||
buffer.put((byte) 'F');
|
||||
|
||||
buffer.putInt(totalLength - 8);
|
||||
|
||||
buffer.put((byte) 'W');
|
||||
buffer.put((byte) 'A');
|
||||
buffer.put((byte) 'V');
|
||||
buffer.put((byte) 'E');
|
||||
|
||||
buffer.put((byte) 'f');
|
||||
buffer.put((byte) 'm');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) ' ');
|
||||
|
||||
buffer.putInt(16);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putShort((short) 1);
|
||||
buffer.putInt(16000);
|
||||
buffer.putInt(32000);
|
||||
buffer.putShort((short) 2);
|
||||
buffer.putShort((short) 16);
|
||||
|
||||
buffer.put((byte) 'd');
|
||||
buffer.put((byte) 'a');
|
||||
buffer.put((byte) 't');
|
||||
buffer.put((byte) 'a');
|
||||
|
||||
buffer.putInt(totalLength - 44);
|
||||
buffer.position(0);
|
||||
|
||||
byte[] bytes = new byte[buffer.limit()];
|
||||
buffer.get(bytes);
|
||||
|
||||
return bytes;
|
||||
}
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class CpuInfo {
|
||||
private static final String LOG_TAG = "WhisperCpuConfig";
|
||||
|
||||
private List<String> lines;
|
||||
|
||||
public CpuInfo(List<String> lines) {
|
||||
this.lines = lines;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public int getHighPerfCpuCount0() {
|
||||
try {
|
||||
return getHighPerfCpuCountByFrequencies();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e);
|
||||
return getHighPerfCpuCountByVariant();
|
||||
}
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByFrequencies() {
|
||||
List<Integer> frequencies = getCpuValues("processor", line -> {
|
||||
try {
|
||||
return getMaxCpuFrequency(Integer.parseInt(line.trim()));
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
);
|
||||
Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): " + binnedValues(frequencies));
|
||||
return countDroppingMin(frequencies);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int getHighPerfCpuCountByVariant() {
|
||||
List<Integer> variants = getCpuValues("CPU variant", line -> Integer.parseInt(line.trim().substring(line.indexOf("0x") + 2), 16));
|
||||
Log.d(LOG_TAG, "Binned cpu variants (variant, count): " + binnedValues(variants));
|
||||
return countKeepingMin(variants);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private Map<Integer, Integer> binnedValues(List<Integer> values) {
|
||||
Map<Integer, Integer> countMap = new HashMap<>();
|
||||
for (int value : values) {
|
||||
countMap.put(value, countMap.getOrDefault(value, 0) + 1);
|
||||
}
|
||||
return countMap;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private List<Integer> getCpuValues(String property, Mapper mapper) {
|
||||
List<Integer> values = new ArrayList<>();
|
||||
for (String line : lines) {
|
||||
if (line.startsWith(property)) {
|
||||
values.add(mapper.map(line.substring(line.indexOf(':') + 1)));
|
||||
}
|
||||
}
|
||||
values.sort(Integer::compareTo);
|
||||
return values;
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countDroppingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value > min).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
private int countKeepingMin(List<Integer> values) {
|
||||
int min = values.stream().mapToInt(i -> i).min().orElse(Integer.MAX_VALUE);
|
||||
return (int) values.stream().filter(value -> value.equals(min)).count();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getHighPerfCpuCount() {
|
||||
try {
|
||||
return readCpuInfo().getHighPerfCpuCount0();
|
||||
} catch (Exception e) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU info", e);
|
||||
return Math.max(Runtime.getRuntime().availableProcessors() - 4, 0);
|
||||
}
|
||||
}
|
||||
|
||||
private static CpuInfo readCpuInfo() throws IOException {
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader("/proc/cpuinfo"))) {
|
||||
List<String> lines = new ArrayList<>();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
lines.add(line);
|
||||
}
|
||||
return new CpuInfo(lines);
|
||||
}
|
||||
}
|
||||
|
||||
private static int getMaxCpuFrequency(int cpuIndex) throws IOException {
|
||||
String path = "/sys/devices/system/cpu/cpu" + cpuIndex + "/cpufreq/cpuinfo_max_freq";
|
||||
try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
|
||||
return Integer.parseInt(reader.readLine());
|
||||
}
|
||||
}
|
||||
|
||||
private interface Mapper {
|
||||
int map(String line);
|
||||
}
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import com.litongjava.whisper.android.java.bean.WhisperSegment;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
public class WhisperContext {
|
||||
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
private long ptr;
|
||||
private final ExecutorService executorService;
|
||||
|
||||
private WhisperContext(long ptr) {
|
||||
this.ptr = ptr;
|
||||
this.executorService = Executors.newSingleThreadExecutor();
|
||||
}
|
||||
|
||||
public String transcribeData(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<String>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public String call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
StringBuilder result = new StringBuilder();
|
||||
synchronized (this) {
|
||||
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
result.append(sentence);
|
||||
}
|
||||
}
|
||||
return result.toString();
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
public List<WhisperSegment> transcribeDataWithTime(float[] data) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(new Callable<List<WhisperSegment>>() {
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
@Override
|
||||
public List<WhisperSegment> call() throws Exception {
|
||||
if (ptr == 0L) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
int numThreads = WhisperCpuConfig.getPreferredThreadCount();
|
||||
Log.d(LOG_TAG, "Selecting " + numThreads + " threads");
|
||||
|
||||
List<WhisperSegment> segments = new ArrayList<>();
|
||||
synchronized (this) {
|
||||
// StringBuilder result = new StringBuilder();
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data);
|
||||
int textCount = WhisperLib.getTextSegmentCount(ptr);
|
||||
for (int i = 0; i < textCount; i++) {
|
||||
long start = WhisperLib.getTextSegmentT0(ptr, i);
|
||||
String sentence = WhisperLib.getTextSegment(ptr, i);
|
||||
long end = WhisperLib.getTextSegmentT1(ptr, i);
|
||||
// result.append();
|
||||
segments.add(new WhisperSegment(start, end, sentence));
|
||||
|
||||
}
|
||||
// return result.toString();
|
||||
}
|
||||
return segments;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchMemory(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchMemcpy(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public String benchGgmlMulMat(int nthreads) throws ExecutionException, InterruptedException {
|
||||
return executorService.submit(() -> WhisperLib.benchGgmlMulMat(nthreads)).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public void release() throws ExecutionException, InterruptedException {
|
||||
executorService.submit(() -> {
|
||||
if (ptr != 0L) {
|
||||
WhisperLib.freeContext(ptr);
|
||||
ptr = 0;
|
||||
}
|
||||
}).get();
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromFile(String filePath) {
|
||||
long ptr = WhisperLib.initContext(filePath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context with path " + filePath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromInputStream(InputStream stream) {
|
||||
long ptr = WhisperLib.initContextFromInputStream(stream);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from input stream");
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static WhisperContext createContextFromAsset(AssetManager assetManager, String assetPath) {
|
||||
long ptr = WhisperLib.initContextFromAsset(assetManager, assetPath);
|
||||
if (ptr == 0L) {
|
||||
throw new RuntimeException("Couldn't create context from asset " + assetPath);
|
||||
}
|
||||
return new WhisperContext(ptr);
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String getSystemInfo() {
|
||||
return WhisperLib.getSystemInfo();
|
||||
}
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
public class WhisperCpuConfig {
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public static int getPreferredThreadCount() {
|
||||
return Math.max(CpuInfo.getHighPerfCpuCount(), 2);
|
||||
}
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.InputStream;
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public class WhisperLib {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
static {
|
||||
|
||||
Log.d(LOG_TAG, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
|
||||
boolean loadVfpv4 = false;
|
||||
boolean loadV8fp16 = false;
|
||||
if (WhisperUtils.isArmEabiV7a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("vfpv4")) {
|
||||
Log.d(LOG_TAG, "CPU supports vfpv4");
|
||||
loadVfpv4 = true;
|
||||
}
|
||||
}
|
||||
} else if (WhisperUtils.isArmEabiV8a()) {
|
||||
String cpuInfo = WhisperUtils.cpuInfo();
|
||||
if (cpuInfo != null) {
|
||||
Log.d(LOG_TAG, "CPU info: " + cpuInfo);
|
||||
if (cpuInfo.contains("fphp")) {
|
||||
Log.d(LOG_TAG, "CPU supports fp16 arithmetic");
|
||||
loadV8fp16 = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (loadVfpv4) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so");
|
||||
System.loadLibrary("whisper_vfpv4");
|
||||
} else if (loadV8fp16) {
|
||||
Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so");
|
||||
System.loadLibrary("whisper_v8fp16_va");
|
||||
} else {
|
||||
Log.d(LOG_TAG, "Loading libwhisper.so");
|
||||
System.loadLibrary("whisper");
|
||||
}
|
||||
}
|
||||
|
||||
public static native long initContextFromInputStream(InputStream inputStream);
|
||||
|
||||
public static native long initContextFromAsset(AssetManager assetManager, String assetPath);
|
||||
|
||||
public static native long initContext(String modelPath);
|
||||
|
||||
public static native void freeContext(long contextPtr);
|
||||
|
||||
public static native void fullTranscribe(long contextPtr, int numThreads, float[] audioData);
|
||||
|
||||
public static native int getTextSegmentCount(long contextPtr);
|
||||
|
||||
public static native String getTextSegment(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT0(long contextPtr, int index);
|
||||
|
||||
public static native long getTextSegmentT1(long contextPtr, int index);
|
||||
|
||||
public static native String getSystemInfo();
|
||||
|
||||
public static native String benchMemcpy(int nthread);
|
||||
|
||||
public static native String benchGgmlMulMat(int nthread);
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package com.whispercpp.java.whisper;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class WhisperUtils {
|
||||
private static final String LOG_TAG = "LibWhisper";
|
||||
|
||||
|
||||
public static boolean isArmEabiV7a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a");
|
||||
}
|
||||
|
||||
public static boolean isArmEabiV8a() {
|
||||
return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
|
||||
}
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.O)
|
||||
public static String cpuInfo() {
|
||||
try {
|
||||
Path path = new File("/proc/cpuinfo").toPath();
|
||||
return new String(java.nio.file.Files.readAllBytes(path));
|
||||
} catch (Exception e) {
|
||||
Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -1,56 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.10)
|
||||
|
||||
project(whisper.cpp)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
|
||||
|
||||
set(
|
||||
SOURCE_FILES
|
||||
${WHISPER_LIB_DIR}/ggml.c
|
||||
${WHISPER_LIB_DIR}/ggml-alloc.c
|
||||
${WHISPER_LIB_DIR}/ggml-backend.c
|
||||
${WHISPER_LIB_DIR}/ggml-quants.c
|
||||
${WHISPER_LIB_DIR}/whisper.cpp
|
||||
${CMAKE_SOURCE_DIR}/jni.c
|
||||
)
|
||||
|
||||
find_library(LOG_LIB log)
|
||||
|
||||
function(build_library target_name)
|
||||
add_library(
|
||||
${target_name}
|
||||
SHARED
|
||||
${SOURCE_FILES}
|
||||
)
|
||||
|
||||
target_link_libraries(${target_name} ${LOG_LIB} android)
|
||||
|
||||
if (${target_name} STREQUAL "whisper_v8fp16_va")
|
||||
target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16)
|
||||
elseif (${target_name} STREQUAL "whisper_vfpv4")
|
||||
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
|
||||
endif ()
|
||||
|
||||
if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
|
||||
|
||||
target_compile_options(${target_name} PRIVATE -O3)
|
||||
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
|
||||
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
|
||||
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
|
||||
#target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
|
||||
#target_link_options(${target_name} PRIVATE -flto)
|
||||
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
build_library("whisper") # Default target
|
||||
|
||||
if (${ANDROID_ABI} STREQUAL "arm64-v8a")
|
||||
build_library("whisper_v8fp16_va")
|
||||
elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a")
|
||||
build_library("whisper_vfpv4")
|
||||
endif ()
|
||||
|
||||
include_directories(${WHISPER_LIB_DIR})
|
@ -1,257 +0,0 @@
|
||||
#include <jni.h>
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <android/log.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/sysinfo.h>
|
||||
#include <string.h>
|
||||
#include "whisper.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define TAG "JNI"
|
||||
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
|
||||
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
|
||||
|
||||
static inline int min(int a, int b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
static inline int max(int a, int b) {
|
||||
return (a > b) ? a : b;
|
||||
}
|
||||
|
||||
struct input_stream_context {
|
||||
size_t offset;
|
||||
JNIEnv * env;
|
||||
jobject thiz;
|
||||
jobject input_stream;
|
||||
|
||||
jmethodID mid_available;
|
||||
jmethodID mid_read;
|
||||
};
|
||||
|
||||
size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
|
||||
|
||||
jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
|
||||
|
||||
jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
|
||||
|
||||
if (size_to_copy != read_size || size_to_copy != n_read) {
|
||||
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
|
||||
}
|
||||
|
||||
jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
|
||||
memcpy(output, byte_array_elements, size_to_copy);
|
||||
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
|
||||
|
||||
(*is->env)->DeleteLocalRef(is->env, byte_array);
|
||||
|
||||
is->offset += size_to_copy;
|
||||
|
||||
return size_to_copy;
|
||||
}
|
||||
bool inputStreamEof(void * ctx) {
|
||||
struct input_stream_context* is = (struct input_stream_context*)ctx;
|
||||
|
||||
jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
|
||||
return result <= 0;
|
||||
}
|
||||
void inputStreamClose(void * ctx) {
|
||||
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromInputStream(
|
||||
JNIEnv *env, jobject thiz, jobject input_stream) {
|
||||
UNUSED(thiz);
|
||||
|
||||
struct whisper_context *context = NULL;
|
||||
struct whisper_model_loader loader = {};
|
||||
struct input_stream_context inp_ctx = {};
|
||||
|
||||
inp_ctx.offset = 0;
|
||||
inp_ctx.env = env;
|
||||
inp_ctx.thiz = thiz;
|
||||
inp_ctx.input_stream = input_stream;
|
||||
|
||||
jclass cls = (*env)->GetObjectClass(env, input_stream);
|
||||
inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
|
||||
inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
|
||||
|
||||
loader.context = &inp_ctx;
|
||||
loader.read = inputStreamRead;
|
||||
loader.eof = inputStreamEof;
|
||||
loader.close = inputStreamClose;
|
||||
|
||||
loader.eof(loader.context);
|
||||
|
||||
context = whisper_init(&loader);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
static size_t asset_read(void *ctx, void *output, size_t read_size) {
|
||||
return AAsset_read((AAsset *) ctx, output, read_size);
|
||||
}
|
||||
|
||||
static bool asset_is_eof(void *ctx) {
|
||||
return AAsset_getRemainingLength64((AAsset *) ctx) <= 0;
|
||||
}
|
||||
|
||||
static void asset_close(void *ctx) {
|
||||
AAsset_close((AAsset *) ctx);
|
||||
}
|
||||
|
||||
static struct whisper_context *whisper_init_from_asset(
|
||||
JNIEnv *env,
|
||||
jobject assetManager,
|
||||
const char *asset_path
|
||||
) {
|
||||
LOGI("Loading model from asset '%s'\n", asset_path);
|
||||
AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager);
|
||||
AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING);
|
||||
if (!asset) {
|
||||
LOGW("Failed to open '%s'\n", asset_path);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
whisper_model_loader loader = {
|
||||
.context = asset,
|
||||
.read = &asset_read,
|
||||
.eof = &asset_is_eof,
|
||||
.close = &asset_close
|
||||
};
|
||||
|
||||
return whisper_init(&loader);
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContextFromAsset(
|
||||
JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL);
|
||||
context = whisper_init_from_asset(env, assetManager, asset_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_initContext(
|
||||
JNIEnv *env, jobject thiz, jstring model_path_str) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = NULL;
|
||||
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
|
||||
context = whisper_init_from_file(model_path_chars);
|
||||
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
|
||||
return (jlong) context;
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_freeContext(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
whisper_free(context);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
params.print_progress = false;
|
||||
params.print_timestamps = true;
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = num_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
||||
whisper_reset_timings(context);
|
||||
|
||||
LOGI("About to run whisper_full");
|
||||
if (whisper_full(context, params, audio_data_arr, audio_data_length) != 0) {
|
||||
LOGI("Failed to run the model");
|
||||
} else {
|
||||
whisper_print_timings(context);
|
||||
}
|
||||
(*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT);
|
||||
}
|
||||
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentCount(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr) {
|
||||
UNUSED(env);
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
return whisper_full_n_segments(context);
|
||||
}
|
||||
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegment(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const char *text = whisper_full_get_segment_text(context, index);
|
||||
jstring string = (*env)->NewStringUTF(env, text);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT0(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t0 = whisper_full_get_segment_t0(context, index);
|
||||
return (jlong)t0;
|
||||
}
|
||||
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getTextSegmentT1(JNIEnv *env, jobject thiz,jlong context_ptr, jint index) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
const int64_t t1 = whisper_full_get_segment_t1(context, index);
|
||||
return (jlong)t1;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_getSystemInfo(
|
||||
JNIEnv *env, jobject thiz
|
||||
) {
|
||||
UNUSED(thiz);
|
||||
const char *sysinfo = whisper_print_system_info();
|
||||
jstring string = (*env)->NewStringUTF(env, sysinfo);
|
||||
return string;
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchMemcpy(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy);
|
||||
}
|
||||
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_whispercpp_java_whisper_WhisperLib_benchGgmlMulMat(JNIEnv *env, jobject thiz,
|
||||
jint n_threads) {
|
||||
UNUSED(thiz);
|
||||
const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads);
|
||||
jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat);
|
||||
}
|
||||
|
@ -1,30 +0,0 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
@ -1,170 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
@ -1,57 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/systemInfoBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="System Info" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/loadModelBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Load model" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content">
|
||||
|
||||
<Button
|
||||
android:id="@+id/transcriptSampleBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Transcribe sample" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/clearBtn"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Clear" />
|
||||
</LinearLayout>
|
||||
|
||||
<TextView
|
||||
android:id="@+id/sample_text"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Hello World!"
|
||||
app:layout_constraintBottom_toBottomOf="parent"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent"
|
||||
android:scrollbarAlwaysDrawHorizontalTrack="true"
|
||||
android:maxLines="999"/>
|
||||
|
||||
</LinearLayout>
|
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
@ -1,5 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
Before Width: | Height: | Size: 3.5 KiB |
Before Width: | Height: | Size: 5.2 KiB |
Before Width: | Height: | Size: 2.6 KiB |
Before Width: | Height: | Size: 3.3 KiB |
Before Width: | Height: | Size: 4.8 KiB |
Before Width: | Height: | Size: 7.3 KiB |
Before Width: | Height: | Size: 7.7 KiB |
Before Width: | Height: | Size: 12 KiB |
Before Width: | Height: | Size: 10 KiB |
Before Width: | Height: | Size: 16 KiB |
@ -1,16 +0,0 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_200</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/black</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_200</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
@ -1,3 +0,0 @@
|
||||
<resources>
|
||||
<string name="app_name">whisper.android.java</string>
|
||||
</resources>
|
@ -1,16 +0,0 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.Whisperandroidjava" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_500</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/white</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_700</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor" tools:targetApi="l">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
@ -1,17 +0,0 @@
|
||||
package com.litongjava.whisper.android.java;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
|
||||
*/
|
||||
public class ExampleUnitTest {
|
||||
@Test
|
||||
public void addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2);
|
||||
}
|
||||
}
|
@ -1,24 +0,0 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
buildscript {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
dependencies {
|
||||
classpath "com.android.tools.build:gradle:4.1.3"
|
||||
|
||||
// NOTE: Do not place your application dependencies here; they belong
|
||||
// in the individual module build.gradle files
|
||||
}
|
||||
}
|
||||
|
||||
allprojects {
|
||||
repositories {
|
||||
google()
|
||||
jcenter()
|
||||
}
|
||||
}
|
||||
|
||||
task clean(type: Delete) {
|
||||
delete rootProject.buildDir
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app"s APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Automatically convert third-party libraries to use AndroidX
|
||||
android.enableJetifier=true
|
@ -1,6 +0,0 @@
|
||||
#Fri Oct 20 11:07:15 HST 2023
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-all.zip
|
172
examples/whisper.android.java/gradlew
vendored
@ -1,172 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS=""
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin, switch paths to Windows format before running java
|
||||
if $cygwin ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=$((i+1))
|
||||
done
|
||||
case $i in
|
||||
(0) set -- ;;
|
||||
(1) set -- "$args0" ;;
|
||||
(2) set -- "$args0" "$args1" ;;
|
||||
(3) set -- "$args0" "$args1" "$args2" ;;
|
||||
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=$(save "$@")
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
|
||||
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
|
||||
cd "$(dirname "$0")"
|
||||
fi
|
||||
|
||||
exec "$JAVACMD" "$@"
|
84
examples/whisper.android.java/gradlew.bat
vendored
@ -1,84 +0,0 @@
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS=
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto init
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:init
|
||||
@rem Get command-line arguments, handling Windows variants
|
||||
|
||||
if not "%OS%" == "Windows_NT" goto win9xME_args
|
||||
|
||||
:win9xME_args
|
||||
@rem Slurp the command line arguments.
|
||||
set CMD_LINE_ARGS=
|
||||
set _SKIP=2
|
||||
|
||||
:win9xME_args_slurp
|
||||
if "x%~1" == "x" goto execute
|
||||
|
||||
set CMD_LINE_ARGS=%*
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
@ -1,2 +0,0 @@
|
||||
include ':app'
|
||||
rootProject.name = "whisper.android.java"
|
2
examples/whisper.android/.idea/gradle.xml
generated
@ -4,7 +4,6 @@
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
|
||||
<option name="modules">
|
||||
@ -14,7 +13,6 @@
|
||||
</set>
|
||||
</option>
|
||||
<option name="resolveExternalAnnotations" value="false" />
|
||||
<option name="resolveModulePerSourceSet" value="false" />
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
|
@ -17,12 +17,12 @@ else
|
||||
encoder_only=$2
|
||||
fi
|
||||
|
||||
models=( \
|
||||
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" "medium-dis" \
|
||||
"large-v2" "large-v2-q4_0" "large-v2-q4_1" "large-v2-q5_0" "large-v2-q5_1" "large-v2-q8_0" "large-v2-dis" \
|
||||
models=( \
|
||||
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
||||
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
|
||||
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
|
||||
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
|
||||
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
|
||||
)
|
||||
|
||||
if [ "$encoder_only" -eq 0 ]; then
|
||||
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# actual run
|
||||
@ -56,7 +56,6 @@ for model in "${models[@]}"; do
|
||||
# parse the output:
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||
batchd_time=$(echo "$output" | grep "batchd time" | awk '{print $11}')
|
||||
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||
system_info=$(echo "$output" | grep "system_info")
|
||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||
@ -84,10 +83,6 @@ for model in "${models[@]}"; do
|
||||
config="$config COREML"
|
||||
fi
|
||||
|
||||
if [[ $system_info == *"CUDA = 1"* ]]; then
|
||||
config="$config CUDA"
|
||||
fi
|
||||
|
||||
if [[ $system_info == *"METAL = 1"* ]]; then
|
||||
config="$config METAL"
|
||||
fi
|
||||
@ -95,6 +90,6 @@ for model in "${models[@]}"; do
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
||||
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
python3 models/convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/
|
||||
|
@ -15,13 +15,33 @@ declare -a filedex
|
||||
cd `dirname $0`
|
||||
cd ../
|
||||
|
||||
for i in `ls ./models | grep ^ggml-.*.bin | grep -v "\-q"`; do
|
||||
m="models/$i"
|
||||
if [ -f "$m" ]; then
|
||||
if [ "${m##*.}" == "bin" ]; then
|
||||
./quantize "${m}" "${m::${#m}-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${m}" "${m::${#m}-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${m::${#m}-4}-${qtype1}.bin" "${m::${#m}-4}-${qtype0}.bin" )
|
||||
# Let's loop across all the objects in the 'models' dir:
|
||||
for i in ./models/*; do
|
||||
# Check to see if it's a file or directory
|
||||
if [ -d "$i" ]; then
|
||||
# It's a directory! We should make sure it's not empty first:
|
||||
if [ "$(ls -A $i)" ]; then
|
||||
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
|
||||
for f in "$i"/*.bin; do
|
||||
# [Neuron Activation]
|
||||
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
|
||||
if [ "$newfile" != "q5" ]; then
|
||||
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
|
||||
fi
|
||||
done
|
||||
fi
|
||||
else
|
||||
# It's a file! Let's make sure it's the right type:
|
||||
if [ "${i##*.}" == "bin" ]; then
|
||||
# And we probably want to skip the testing files
|
||||
if [ "${i:9:8}" != "for-test" ]; then
|
||||
# [Neuron Activation]
|
||||
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
|
||||
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
|
||||
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
23
ggml-alloc.c
@ -446,14 +446,12 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
|
||||
return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
|
||||
}
|
||||
|
||||
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
|
||||
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view) {
|
||||
ggml_tallocr_t alloc = node_tallocr(galloc, view);
|
||||
|
||||
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
|
||||
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
|
||||
if (update_backend) {
|
||||
view->backend = view->view_src->backend;
|
||||
}
|
||||
view->backend = view->view_src->backend;
|
||||
view->buffer = view->view_src->buffer;
|
||||
view->data = (char *)view->view_src->data + view->view_offs;
|
||||
|
||||
@ -471,7 +469,7 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
|
||||
|
||||
if (node->data == NULL) {
|
||||
if (ggml_is_view(node)) {
|
||||
init_view(galloc, node, true);
|
||||
init_view(galloc, node);
|
||||
} else {
|
||||
// see if we can reuse a parent's buffer (inplace)
|
||||
if (ggml_op_can_inplace(node->op)) {
|
||||
@ -501,14 +499,15 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
|
||||
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
|
||||
node->view_src = view_src;
|
||||
view_src_hn->n_views += 1;
|
||||
init_view(galloc, node, false);
|
||||
init_view(galloc, node);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
|
||||
node->view_src = parent;
|
||||
p_hn->n_views += 1;
|
||||
init_view(galloc, node, false);
|
||||
init_view(galloc, node);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -538,7 +537,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
hash_get(galloc, view_src)->n_views += 1;
|
||||
if (node->buffer == NULL && node->data != NULL) {
|
||||
// view of a pre-allocated tensor, didn't call init_view() yet
|
||||
init_view(galloc, node, true);
|
||||
init_view(galloc, node);
|
||||
}
|
||||
}
|
||||
|
||||
@ -549,7 +548,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
}
|
||||
hash_get(galloc, parent)->n_children += 1;
|
||||
if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
|
||||
init_view(galloc, parent, true);
|
||||
init_view(galloc, parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -664,7 +663,7 @@ size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, st
|
||||
return max_size;
|
||||
}
|
||||
|
||||
void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
|
||||
void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_alloct) {
|
||||
const size_t hash_size = hash_set.size;
|
||||
|
||||
GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
|
||||
@ -687,7 +686,7 @@ void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * grap
|
||||
// reset hash values
|
||||
memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
|
||||
|
||||
galloc->hash_allocs = hash_node_talloc;
|
||||
galloc->hash_allocs = hash_node_alloct;
|
||||
|
||||
ggml_tallocr_alloc_graph_impl(galloc, graph);
|
||||
|
||||
|
535
ggml-cuda.cu
@ -17,12 +17,7 @@ extern "C" {
|
||||
|
||||
#define GGML_CUDA_MAX_DEVICES 16
|
||||
|
||||
// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
|
||||
GGML_API void ggml_init_cublas(void);
|
||||
|
||||
// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
|
||||
GGML_API bool ggml_cublas_loaded(void);
|
||||
|
||||
GGML_API void * ggml_cuda_host_malloc(size_t size);
|
||||
GGML_API void ggml_cuda_host_free(void * ptr);
|
||||
|
||||
|
@ -26,7 +26,7 @@
|
||||
#include <stdbool.h>
|
||||
|
||||
// max memory buffers that can be mapped to the device
|
||||
#define GGML_METAL_MAX_BUFFERS 64
|
||||
#define GGML_METAL_MAX_BUFFERS 16
|
||||
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
|
||||
|
||||
struct ggml_tensor;
|
||||
@ -52,11 +52,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx);
|
||||
void * ggml_metal_host_malloc(size_t n);
|
||||
void ggml_metal_host_free (void * data);
|
||||
|
||||
// helper to check if the device supports a specific family
|
||||
// ideally, the user code should be doing these checks
|
||||
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family);
|
||||
|
||||
// set the number of command buffers to use
|
||||
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
|
||||
|
||||
@ -105,8 +100,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||
|
||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
||||
|
||||
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
119
ggml-metal.m
@ -86,7 +86,6 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -115,7 +114,6 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
@ -128,7 +126,7 @@ struct ggml_metal_context {
|
||||
// MSL code
|
||||
// TODO: move the contents here when ready
|
||||
// for now it is easier to work in a separate file
|
||||
//static NSString * const msl_library_source = @"see metal.metal";
|
||||
static NSString * const msl_library_source = @"see metal.metal";
|
||||
|
||||
// Here to assist with NSBundle Path Hack
|
||||
@interface GGMLMetalClass : NSObject
|
||||
@ -144,8 +142,7 @@ void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_dat
|
||||
ggml_metal_log_user_data = user_data;
|
||||
}
|
||||
|
||||
GGML_ATTRIBUTE_FORMAT(2, 3)
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
||||
static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
|
||||
if (ggml_metal_log_callback != NULL) {
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
@ -290,7 +287,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -321,7 +317,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
@ -340,15 +335,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
||||
if ([ctx->device supportsFamily:i]) {
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
if (ctx->device.maxTransferRate != 0) {
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
} else {
|
||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
}
|
||||
@ -391,7 +386,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||
@ -422,7 +416,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
@ -459,10 +452,6 @@ void ggml_metal_host_free(void * data) {
|
||||
free(data);
|
||||
}
|
||||
|
||||
bool ggml_metal_supports_family(struct ggml_metal_context * ctx, int family) {
|
||||
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
}
|
||||
|
||||
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
||||
}
|
||||
@ -484,10 +473,6 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
||||
|
||||
const int64_t tsize = ggml_nbytes(t);
|
||||
|
||||
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
|
||||
ctx = t->buffer->backend->context;
|
||||
}
|
||||
|
||||
// find the view that contains the tensor fully
|
||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||
@ -545,11 +530,11 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1e6);
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1e6);
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
|
||||
++ctx->n_buffers;
|
||||
} else {
|
||||
@ -569,11 +554,11 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1e6);
|
||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1e6, i);
|
||||
GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
||||
if (i + size_step < size) {
|
||||
GGML_METAL_LOG_INFO("\n");
|
||||
}
|
||||
@ -584,16 +569,16 @@ bool ggml_metal_add_buffer(
|
||||
|
||||
#if TARGET_OS_OSX
|
||||
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
||||
ctx->device.currentAllocatedSize / 1e6,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
||||
} else {
|
||||
GGML_METAL_LOG_INFO("\n");
|
||||
}
|
||||
#else
|
||||
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1e6);
|
||||
GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1076,7 +1061,7 @@ void ggml_metal_graph_compute(
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
const unsigned int gqa = ne12/ne02;
|
||||
const uint gqa = ne12/ne02;
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
@ -1154,7 +1139,6 @@ void ggml_metal_graph_compute(
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
||||
nrows = 4;
|
||||
} break;
|
||||
@ -1162,18 +1146,13 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
if (src1t == GGML_TYPE_F32) {
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||
nrows = 4;
|
||||
}
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||
nrows = 4;
|
||||
}
|
||||
} break;
|
||||
@ -1485,58 +1464,6 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_IM2COL:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||
|
||||
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
||||
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
||||
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
||||
const int32_t IW = src1->ne[0];
|
||||
|
||||
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
||||
const int32_t KW = src0->ne[0];
|
||||
|
||||
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const int32_t OW = dst->ne[1];
|
||||
|
||||
const int32_t CHW = IC * KH * KW;
|
||||
|
||||
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
||||
default: GGML_ASSERT(false);
|
||||
};
|
||||
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
@ -1755,9 +1682,3 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||
|
||||
ggml_metal_set_n_cb(ctx, n_cb);
|
||||
}
|
||||
|
||||
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
||||
|
||||
return ggml_metal_supports_family(ctx, family);
|
||||
}
|
||||
|
108
ggml-metal.metal
@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F32_F32;
|
||||
@ -844,79 +844,6 @@ kernel void kernel_mul_mv_f32_f32(
|
||||
}
|
||||
}
|
||||
|
||||
#define N_F16_F16 4
|
||||
|
||||
kernel void kernel_mul_mv_f16_f16(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F16_F16;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (half) x[i] * (half) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *)x;
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
device const half4 * y4 = (device const half4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@ -1302,39 +1229,6 @@ kernel void kernel_rope(
|
||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||
|
||||
kernel void kernel_im2col_f16(
|
||||
device const float * x,
|
||||
device half * dst,
|
||||
constant int32_t & ofs0,
|
||||
constant int32_t & ofs1,
|
||||
constant int32_t & IW,
|
||||
constant int32_t & IH,
|
||||
constant int32_t & CHW,
|
||||
constant int32_t & s0,
|
||||
constant int32_t & s1,
|
||||
constant int32_t & p0,
|
||||
constant int32_t & p1,
|
||||
constant int32_t & d0,
|
||||
constant int32_t & d1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
|
||||
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
|
||||
|
||||
const int32_t offset_dst =
|
||||
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
|
||||
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
|
||||
dst[offset_dst] = x[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
@ -1368,12 +1368,7 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f
|
||||
float max = x[0];
|
||||
float sum_w = weights[0];
|
||||
float sum_x = sum_w * x[0];
|
||||
#ifdef HAVE_BUGGY_APPLE_LINKER
|
||||
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
|
||||
for (volatile int i = 1; i < n; ++i) {
|
||||
#else
|
||||
for (int i = 1; i < n; ++i) {
|
||||
#endif
|
||||
if (x[i] < min) min = x[i];
|
||||
if (x[i] > max) max = x[i];
|
||||
float w = weights[i];
|
||||
|