Compare commits

..

1 Commits

Author SHA1 Message Date
18d5ff8695 Added GitHub workflow for deb package build 2023-05-06 11:04:04 +01:00
111 changed files with 5124 additions and 47281 deletions

View File

@ -1,41 +1,31 @@
name: CI
on: [push, pull_request]
env:
ubuntu_image: "ubuntu:22.04"
jobs:
ubuntu-latest:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
- name: Dependencies
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential libsdl2-dev
make
make stream'
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install libsdl2-dev
- name: Build
run: |
make
make stream
macOS-latest:
runs-on: macOS-latest
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Dependencies
run: |
@ -47,104 +37,82 @@ jobs:
make
make stream
freeBSD-latest:
runs-on: macos-12
steps:
- name: Clone
uses: actions/checkout@v3
- name: Build
uses: cross-platform-actions/action@v0.15.0
with:
operating_system: freebsd
version: '13.2'
run: |
sudo pkg update
sudo pkg install -y gmake sdl2
gmake
gmake stream
ubuntu-latest-gcc:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
- name: Dependencies
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ctest -L gh --output-on-failure'
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
- name: Build
run: |
make
ctest -L gh --output-on-failure
ubuntu-latest-clang:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
build: [Debug, Release]
arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le]
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
- name: Dependencies
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
sudo apt-get install libsdl2-dev
- name: Configure
run: cmake . -DWHISPER_SUPPORT_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
- name: Build
run: |
make
ctest -L gh --output-on-failure
ubuntu-latest-gcc-sanitized:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
sanitizer: [ADDRESS, THREAD, UNDEFINED]
arch: [linux/amd64]
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Build ${{ matrix.arch }}
- name: Dependencies
run: |
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
apt update
apt install -y build-essential cmake
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
make
ctest -L gh --output-on-failure'
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get install cmake
- name: Configure
run: cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
- name: Build
run: |
make
ctest -L gh --output-on-failure
windows:
runs-on: windows-latest
@ -157,16 +125,14 @@ jobs:
include:
- arch: Win32
s2arc: x86
jnaPath: win32-x86
- arch: x64
s2arc: x64
jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@ -193,12 +159,6 @@ jobs:
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload dll
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.jnaPath }}_whisper.dll
path: build/bin/${{ matrix.build }}/whisper.dll
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
@ -227,7 +187,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@ -293,7 +253,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v1
@ -340,16 +300,24 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Setup emsdk
uses: mymindstorm/setup-emsdk@v12
- name: Dependencies
run: |
wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz
tar -xvf master.tar.gz
emsdk-master/emsdk update
emsdk-master/emsdk install latest
emsdk-master/emsdk activate latest
- name: Verify
run: emcc -v
- name: Configure
run: echo "tmp"
- name: Build
run: |
pushd emsdk-master
source ./emsdk_env.sh
popd
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
@ -362,12 +330,10 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Configure
run: |
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
run: cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
@ -380,7 +346,7 @@ jobs:
steps:
- name: Clone
uses: actions/checkout@v3
uses: actions/checkout@v1
- name: Install Java
uses: actions/setup-java@v3
@ -395,55 +361,3 @@ jobs:
run: |
cd examples/whisper.android
./gradlew assembleRelease --no-daemon
java:
needs: [ 'windows' ]
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- name: Install Java
uses: actions/setup-java@v1
with:
java-version: 17
- name: Download Windows lib
uses: actions/download-artifact@v3
with:
name: win32-x86-64_whisper.dll
path: bindings/java/build/generated/resources/main/win32-x86-64
- name: Build
run: |
models\download-ggml-model.cmd tiny.en
cd bindings/java
chmod +x ./gradlew
./gradlew build
- name: Upload jar
uses: actions/upload-artifact@v3
with:
name: whispercpp.jar
path: bindings/java/build/libs/whispercpp-*.jar
- name: Publish package
if: ${{ github.ref == 'refs/heads/master' }}
uses: gradle/gradle-build-action@v2
with:
arguments: publish
env:
MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
quantize:
runs-on: ubuntu-latest
steps:
- name: Clone
uses: actions/checkout@v3
- name: Test quantize
run: |
./models/download-ggml-model.sh tiny.en
make quantize
./quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0

64
.github/workflows/release-deb.yml vendored Normal file
View File

@ -0,0 +1,64 @@
name: release-deb
on:
release:
types: [created]
jobs:
build:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Configure
run: |
set -x -e
VERSION=$(echo $GITHUB_REF | cut --delimiter=/ -f 3)
ID="whisper-cpp-small_${VERSION}_amd64"
echo "PKG_VERSION=$VERSION" >> $GITHUB_ENV
echo "PKG_ID=$ID" >> $GITHUB_ENV
- name: Install deps
run: |
sudo apt install -y --no-install-recommends intel-mkl
- name: Build
run: |
cmake -S . -B build-release -D BUILD_SHARED_LIBS=OFF
cd build-release
make
cd ..
- name: Create package tree
env:
GITHUB_REPO: ${{ github.repository }}
run: |
export ROOT=$PKG_ID/opt/project/whisper.cpp
mkdir -p $ROOT/bin
mkdir -p $ROOT/share
mkdir -p $PKG_ID/DEBIAN
cp build-release/bin/main $ROOT/bin/whisper
cp -r contrib/debian/control $PKG_ID/DEBIAN/
echo "Version: $PKG_VERSION" >> $PKG_ID/DEBIAN/control
echo "Git-Repo: $GITHUB_REPO" >> $PKG_ID/DEBIAN/control
echo "Git-Commit: $GITHUB_SHA" >> $PKG_ID/DEBIAN/control
models/download-ggml-model.sh small
build-release/bin/quantize models/ggml-small.bin \
$ROOT/share/ggml-small-q5_1.bin q5_1
- name: Create deb package
run: |
mkdir artifacts
dpkg-deb --build --root-owner-group $PKG_ID
- name: Upload Release Asset
uses: xresloader/upload-to-github-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
release_id: ${{ github.event.release.id }}
file: ${{ env.PKG_ID }}.deb

6
.gitignore vendored
View File

@ -5,6 +5,7 @@
.test/
.vs/
.vscode/
.idea/
.DS_Store
build/
@ -16,6 +17,7 @@ build-cublas/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
cmake-build-debug/
/main
/stream
@ -24,7 +26,6 @@ build-sanitize-thread/
/talk-llama
/bench
/quantize
/lsp
arm_neon.h
sync.sh
@ -42,6 +43,3 @@ extra/bench-gg.txt
models/*.mlmodel
models/*.mlmodelc
models/*.mlpackage
bindings/java/.gradle/
bindings/java/.idea/
.idea/

View File

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.0)
project(whisper.cpp VERSION 1.4.2)
project(whisper.cpp VERSION 1.4.1)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@ -54,19 +54,14 @@ option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
option(WHISPER_NO_F16C "whisper: disable F16c" OFF)
option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
if (APPLE)
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
else()
option(WHISPER_BLAS "whisper: use BLAS libraries" OFF)
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
option(WHISPER_OPENBLAS "whisper: support for OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif()
option(WHISPER_PERF "whisper: enable perf timings" OFF)
@ -126,47 +121,25 @@ if (APPLE)
endif()
if (WHISPER_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_COREML_ALLOW_FALLBACK)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML_ALLOW_FALLBACK)
endif()
endif()
endif()
if (WHISPER_OPENBLAS)
set(WHISPER_BLAS_VENDOR "OpenBLAS")
set(WHISPER_BLAS ON)
find_library(OPENBLAS_LIB
NAMES openblas libopenblas
)
if (OPENBLAS_LIB)
message(STATUS "OpenBLAS found")
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${OPENBLAS_LIB})
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
else()
message(WARNING "OpenBLAS not found")
endif()
endif()
if (WHISPER_BLAS)
if (WIN32)
if(DEFINED ENV{OPENBLAS_PATH})
set(BLAS_LIBRARIES $ENV{OPENBLAS_PATH}/lib/libopenblas.dll.a)
message(STATUS "Libraries ${BLAS_LIBRARIES}")
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories($ENV{OPENBLAS_PATH}/include)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else ()
message(WARNING "BLAS library was not found. Environment variable OPENBLAS_PATH not defined.")
endif ()
else ()
set(BLA_STATIC 1)
set(BLA_VENDOR ${WHISPER_BLAS_VENDOR})
# set(BLA_PREFER_PKGCONFIG 1)
set(BLA_SIZEOF_INTEGER 8)
find_package(BLAS)
if(BLAS_FOUND)
message(STATUS "BLAS compatible library found")
message(STATUS "Libraries ${BLAS_LIBRARIES}")
find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include/openblas /usr/local/include/openblas $ENV{BLAS_HOME}/include)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_OPENBLAS)
include_directories(${BLAS_INCLUDE_DIRS})
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${BLAS_LIBRARIES})
else()
message(WARNING "BLAS library was not found")
endif()
endif ()
endif ()
if (WHISPER_CUBLAS)
cmake_minimum_required(VERSION 3.17)
@ -192,43 +165,12 @@ if (WHISPER_CUBLAS)
endif()
endif()
if (WHISPER_HIPBLAS)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
find_package(hip)
find_package(hipblas)
find_package(rocblas)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
if (WHISPER_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
endif()
endif()
if (WHISPER_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h)
add_compile_definitions(GGML_USE_CLBLAST)
@ -238,10 +180,6 @@ if (WHISPER_CLBLAST)
endif()
endif()
if( WHISPER_OPENVINO )
find_package(OpenVINO REQUIRED COMPONENTS Runtime)
endif()
# compiler flags
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
@ -281,25 +219,12 @@ message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
message(STATUS "PowerPC detected")
else()
message(STATUS "x86 detected")
if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /utf-8")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /utf-8")
if(NOT WHISPER_NO_AVX2)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX")
endif()
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /arch:AVX2")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2")
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
@ -352,24 +277,6 @@ if (WHISPER_COREML)
)
endif()
if (WHISPER_OPENVINO)
set(TARGET whisper.openvino)
add_library(${TARGET} OBJECT
openvino/whisper-openvino-encoder.h
openvino/whisper-openvino-encoder.cpp
)
target_include_directories(${TARGET} PUBLIC
.
)
set_property(TARGET ${TARGET} PROPERTY POSITION_INDEPENDENT_CODE ON)
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_OPENVINO)
target_link_libraries(${TARGET} PRIVATE openvino::runtime)
endif()
#
# whisper - this is the main library of the project
#
@ -395,10 +302,6 @@ if (WHISPER_COREML)
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
endif()
if (WHISPER_OPENVINO)
target_link_libraries(${TARGET} PRIVATE whisper.openvino)
endif()
if (MSVC)
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})

148
Makefile
View File

@ -12,12 +12,6 @@ ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
ifndef NVCC_VERSION
ifeq ($(call,$(shell which nvcc))$(.SHELLSTATUS),0)
NVCC_VERSION := $(shell nvcc --version | egrep -o "V[0-9]+.[0-9]+.[0-9]+" | cut -c2-)
endif
endif
CCV := $(shell $(CC) --version | head -n 1)
CXXV := $(shell $(CXX) --version | head -n 1)
@ -48,16 +42,21 @@ ifneq ($(wildcard /usr/include/musl/*),)
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
endif
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
# and on macOS its availability depends on enabling Darwin extensions
ifeq ($(UNAME_S),Darwin)
CFLAGS += -D_DARWIN_C_SOURCE
CXXFLAGS += -D_DARWIN_C_SOURCE
endif
# OS specific
# TODO: support Windows
ifeq ($(filter $(UNAME_S),Linux Darwin DragonFly FreeBSD NetBSD OpenBSD Haiku),$(UNAME_S))
ifeq ($(UNAME_S),Linux)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),Darwin)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),FreeBSD)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
ifeq ($(UNAME_S),Haiku)
CFLAGS += -pthread
CXXFLAGS += -pthread
endif
@ -67,50 +66,60 @@ endif
# feel free to update the Makefile for your architecture and send a pull request or issue
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifeq ($(UNAME_S),Darwin)
CPUINFO_CMD := sysctl machdep.cpu.features
else ifeq ($(UNAME_S),Linux)
CPUINFO_CMD := cat /proc/cpuinfo
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
CPUINFO_CMD := cat /proc/cpuinfo
else ifeq ($(UNAME_S),Haiku)
CPUINFO_CMD := sysinfo -cpu
endif
ifdef CPUINFO_CMD
AVX_M := $(shell $(CPUINFO_CMD) | grep -m 1 "avx ")
ifneq (,$(findstring avx,$(AVX_M)))
CFLAGS += -mf16c
AVX1_M := $(shell sysctl machdep.cpu.features)
ifneq (,$(findstring FMA,$(AVX1_M)))
CFLAGS += -mfma
endif
ifneq (,$(findstring AVX1.0,$(AVX1_M)))
CFLAGS += -mavx
endif
AVX2_M := $(shell $(CPUINFO_CMD) | grep -m 1 "avx2 ")
AVX2_M := $(shell sysctl machdep.cpu.leaf7_features)
ifneq (,$(findstring AVX2,$(AVX2_M)))
CFLAGS += -mavx2
endif
else ifeq ($(UNAME_S),Linux)
AVX2_M := $(shell grep "avx2 " /proc/cpuinfo)
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell $(CPUINFO_CMD) | grep -m 1 "fma ")
FMA_M := $(shell grep "fma " /proc/cpuinfo)
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell $(CPUINFO_CMD) | grep -m 1 "f16c ")
F16C_M := $(shell grep "f16c " /proc/cpuinfo)
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell $(CPUINFO_CMD) | grep -m 1 "avx ")
AVX1_M := $(shell grep "avx " /proc/cpuinfo)
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
SSE3_M := $(shell $(CPUINFO_CMD) | grep -m 1 "sse3 ")
SSE3_M := $(shell grep "sse3 " /proc/cpuinfo)
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
SSSE3_M := $(shell $(CPUINFO_CMD) | grep -m 1 "ssse3 ")
ifneq (,$(findstring ssse3,$(SSSE3_M)))
CFLAGS += -mssse3
else ifeq ($(UNAME_S),Haiku)
AVX2_M := $(shell sysinfo -cpu | grep "AVX2 ")
ifneq (,$(findstring avx2,$(AVX2_M)))
CFLAGS += -mavx2
endif
FMA_M := $(shell sysinfo -cpu | grep "FMA ")
ifneq (,$(findstring fma,$(FMA_M)))
CFLAGS += -mfma
endif
F16C_M := $(shell sysinfo -cpu | grep "F16C ")
ifneq (,$(findstring f16c,$(F16C_M)))
CFLAGS += -mf16c
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))
CFLAGS += -mavx
endif
endif
else
CFLAGS += -mfma -mf16c -mavx -mavx2
endif
endif
ifeq ($(UNAME_M),amd64)
@ -146,56 +155,29 @@ endif
endif
ifdef WHISPER_OPENBLAS
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
LDFLAGS += -lopenblas
endif
ifdef WHISPER_CUBLAS
ifeq ($(shell expr $(NVCC_VERSION) \>= 11.6), 1)
CUDA_ARCH_FLAG=native
else
CUDA_ARCH_FLAG=all
endif
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif
ifdef WHISPER_HIPBLAS
ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
LDFLAGS += -lhipblas -lamdhip64 -lrocblas
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
WHISPER_OBJ += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif
ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST
LDFLAGS += -lclblast
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -framework OpenCL
else
LDFLAGS += -lOpenCL
endif
LDFLAGS += -lclblast -lOpenCL
WHISPER_OBJ += ggml-opencl.o
ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
$(CXX) $(CXXFLAGS) -c $< -o $@
ggml-opencl.o: ggml-opencl.c ggml-opencl.h
$(CC) $(CFLAGS) -c $< -o $@
endif
ifdef WHISPER_GPROF
@ -258,7 +240,7 @@ ifndef WHISPER_COREML
WHISPER_OBJ += whisper.o
else
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder.mm -o whisper-encoder.o
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
@ -273,7 +255,7 @@ libwhisper.so: ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench quantize libwhisper.a libwhisper.so
#
# Examples
@ -300,9 +282,6 @@ stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHIS
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
@ -322,19 +301,12 @@ samples:
@wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
@wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
@wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
@wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
@wget --quiet --show-progress -O samples/diffusion2023-07-03.flac https://archive.org/download/diffusion2023-07-03/diffusion2023-07-03.flac
@echo "Converting to 16-bit WAV ..."
@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
@ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
@ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
@rm samples/*.ogg
@ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
@rm samples/mm1.wav
@ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
@rm samples/a13.mp3
@ffmpeg -loglevel -0 -y -i samples/diffusion2023-07-03.flac -ar 16000 -ac 1 -c:a pcm_s16le samples/diffusion2023-07-03.wav
@rm samples/diffusion2023-07-03.flac
#
# Models
@ -376,4 +348,4 @@ tiny.en tiny base.en base small.en small medium.en medium large-v1 large: main
.PHONY: tests
tests:
bash ./tests/run-tests.sh $(word 2, $(MAKECMDGOALS))
bash ./tests/run-tests.sh

141
README.md
View File

@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Beta: [v1.4.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.1) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -21,8 +21,6 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp
- Runs on the CPU
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
- [OpenVINO Support](https://github.com/ggerganov/whisper.cpp#openvino-support)
- [C-style API](https://github.com/ggerganov/whisper.cpp/blob/master/whisper.h)
Supported platforms:
@ -30,7 +28,6 @@ Supported platforms:
- [x] Mac OS (Intel and Arm)
- [x] [iOS](examples/whisper.objc)
- [x] [Android](examples/whisper.android)
- [x] [Java](bindings/java/README.md)
- [x] Linux / [FreeBSD](https://github.com/ggerganov/whisper.cpp/issues/56#issuecomment-1350920264)
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168)]
@ -61,7 +58,7 @@ Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD
intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
instrisics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since
the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Quick start
@ -74,8 +71,6 @@ Then, download one of the Whisper models converted in [ggml format](models). For
bash ./models/download-ggml-model.sh base.en
```
If you wish to convert the Whisper models to ggml format yourself, instructions are in [models/README.md](models/README.md).
Now build the [main](examples/main) example and transcribe an audio file like this:
```bash
@ -116,7 +111,6 @@ options:
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-tr, --translate [false ] translate from source language to english
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-di, --diarize [false ] stereo audio diarization
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
@ -265,12 +259,6 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
pip install coremltools
```
- To ensure `coremltools` operates correctly, please confirm that [Xcode](https://developer.apple.com/xcode/) is installed and execute `xcode-select --install` to install the command-line tools.
- Python 3.10 is recommended.
- [OPTIONAL] It is recommended to utilize a Python version management system, such as [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for this step:
- To create an environment, use: `conda create -n py310-whisper python=3.10 -y`
- To activate the environment, use: `conda activate py310-whisper`
- Generate a Core ML model. For example, to generate a `base.en` model, use:
```bash
@ -312,88 +300,9 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
For more information about the Core ML implementation please refer to PR [#566](https://github.com/ggerganov/whisper.cpp/pull/566).
## OpenVINO support
On platforms that support [OpenVINO](https://github.com/openvinotoolkit/openvino), the Encoder inference can be executed
on OpenVINO-supported devices including x86 CPUs and Intel GPUs (integrated & discrete).
This can result in significant speedup in encoder performance. Here are the instructions for generating the OpenVINO model and using it with `whisper.cpp`:
- First, setup python virtual env. and install python dependencies. Python 3.10 is recommended.
Windows:
```
cd models
python -m venv openvino_conv_env
openvino_conv_env\Scripts\activate
python -m pip install --upgrade pip
pip install -r openvino-conversion-requirements.txt
```
Linux and macOS:
```
cd models
python3 -m venv openvino_conv_env
source openvino_conv_env/bin/activate
python -m pip install --upgrade pip
pip install -r openvino-conversion-requirements.txt
```
- Generate an OpenVINO encoder model. For example, to generate a `base.en` model, use:
```
python convert-whisper-to-openvino.py --model base.en
```
This will produce ggml-base.en-encoder-openvino.xml/.bin IR model files. It's recommended to relocate these to the same folder as ggml models, as that
is the default location that the OpenVINO extension will search at runtime.
- Build `whisper.cpp` with OpenVINO support:
Download OpenVINO package from [release page](https://github.com/openvinotoolkit/openvino/releases). The recommended version to use is [2023.0.0](https://github.com/openvinotoolkit/openvino/releases/tag/2023.0.0).
After downloading & extracting package onto your development system, set up required environment by sourcing setupvars script. For example:
Linux:
```bash
source /path/to/l_openvino_toolkit_ubuntu22_2023.0.0.10926.b4452d56304_x86_64/setupvars.sh
```
Windows (cmd):
```
C:\Path\To\w_openvino_toolkit_windows_2023.0.0.10926.b4452d56304_x86_64\setupvars.bat
```
And then build the project using cmake:
```bash
cd build
cmake -DWHISPER_OPENVINO=1 ..
```
- Run the examples as usual. For example:
```bash
./main -m models/ggml-base.en.bin -f samples/jfk.wav
...
whisper_ctx_init_openvino_encoder: loading OpenVINO model from 'models/ggml-base.en-encoder-openvino.xml'
whisper_ctx_init_openvino_encoder: first run on a device may take a while ...
whisper_openvino_init: path_model = models/ggml-base.en-encoder-openvino.xml, device = GPU, cache_dir = models/ggml-base.en-encoder-openvino-cache
whisper_ctx_init_openvino_encoder: OpenVINO model loaded
system_info: n_threads = 4 / 8 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | OPENVINO = 1 |
...
```
The first time run on an OpenVINO device is slow, since the OpenVINO framework will compile the IR (Intermediate Representation) model to a device-specific 'blob'. This device-specific blob will get
cached for the next run.
For more information about the Core ML implementation please refer to PR [#1037](https://github.com/ggerganov/whisper.cpp/pull/1037).
## NVIDIA GPU support via cuBLAS
With NVIDIA cards the Encoder processing can to a large extent be offloaded to the GPU through cuBLAS.
With NVIDIA cards, the Encoder processing can be offloaded to the GPU to a large extend through cuBLAS.
First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-downloads
Now build `whisper.cpp` with cuBLAS support:
@ -405,7 +314,7 @@ WHISPER_CUBLAS=1 make -j
## OpenCL GPU support via CLBlast
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APUs or low end devices for up to ~2x speedup.
For cards and integrated GPUs that support OpenCL, the Encoder processing can be largely offloaded to the GPU through CLBlast. This is especially useful for users with AMD APU's or low end devices for up to ~2x speedup.
First, make sure you have installed `CLBlast` for your OS or Distribution: https://github.com/CNugteren/CLBlast
@ -428,18 +337,6 @@ cp bin/* ../
Run all the examples as usual.
## BLAS CPU support via OpenBLAS
Encoder processing can be accelerated on the CPU via OpenBLAS.
First, make sure you have installed `openblas`: https://www.openblas.net/
Now build `whisper.cpp` with OpenBLAS support:
```
make clean
WHISPER_OPENBLAS=1 make -j
```
## Limitations
- Inference only
@ -574,7 +471,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.020 --> 00:00:11.000] country.
```
## Word-level timestamp (experimental)
## Word-level timestamp
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
@ -615,32 +512,6 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
[00:00:10.510 --> 00:00:11.000] .
```
## Speaker segmentation via tinydiarize (experimental)
More information about this approach is available here: https://github.com/ggerganov/whisper.cpp/pull/1058
Sample usage:
```py
# download a tinydiarize compatible model
./models/download-ggml-model.sh small.en-tdrz
# run as usual, adding the "-tdrz" command-line argument
./main -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
...
main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, tdrz = 1, timestamps = 1 ...
...
[00:00:00.000 --> 00:00:03.800] Okay Houston, we've had a problem here. [SPEAKER_TURN]
[00:00:03.800 --> 00:00:06.200] This is Houston. Say again please. [SPEAKER_TURN]
[00:00:06.200 --> 00:00:08.260] Uh Houston we've had a problem.
[00:00:08.260 --> 00:00:11.320] We've had a main beam up on a volt. [SPEAKER_TURN]
[00:00:11.320 --> 00:00:13.820] Roger main beam interval. [SPEAKER_TURN]
[00:00:13.820 --> 00:00:15.100] Uh uh [SPEAKER_TURN]
[00:00:15.100 --> 00:00:18.020] So okay stand, by thirteen we're looking at it. [SPEAKER_TURN]
[00:00:18.020 --> 00:00:25.740] Okay uh right now uh Houston the uh voltage is uh is looking good um.
[00:00:27.620 --> 00:00:29.940] And we had a a pretty large bank or so.
```
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
@ -724,8 +595,6 @@ in [models](models).
- [X] Javascript: [bindings/javascript](bindings/javascript) | [#309](https://github.com/ggerganov/whisper.cpp/discussions/309)
- React Native (iOS / Android): [whisper.rn](https://github.com/mybigday/whisper.rn)
- [X] Go: [bindings/go](bindings/go) | [#312](https://github.com/ggerganov/whisper.cpp/discussions/312)
- [X] Java:
- [GiviMAD/whisper-jni](https://github.com/GiviMAD/whisper-jni)
- [X] Ruby: [bindings/ruby](bindings/ruby) | [#507](https://github.com/ggerganov/whisper.cpp/discussions/507)
- [X] Objective-C / Swift: [ggerganov/whisper.spm](https://github.com/ggerganov/whisper.spm) | [#313](https://github.com/ggerganov/whisper.cpp/discussions/313)
- [exPHAT/SwiftWhisper](https://github.com/exPHAT/SwiftWhisper)

View File

@ -32,7 +32,7 @@ mkdir:
modtidy:
@go mod tidy
clean:
clean:
@echo Clean
@rm -fr $(BUILD_DIR)
@go clean

View File

@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil); err != nil {
if err := context.Process(samples, nil); err != nil {
return err
}
@ -71,7 +71,7 @@ The examples are placed in the `build` directory. Once built, you can download a
And you can then test a model against samples with the following command:
```bash
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
./build/go-whisper -model models/ggml-tiny.en.bin samples/jfk.wav
```
## Using the bindings

View File

@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb, nil); err != nil {
if err := context.Process(data, cb); err != nil {
return err
}

View File

@ -19,10 +19,6 @@ func (p *Params) SetTranslate(v bool) {
p.translate = toBool(v)
}
func (p *Params) SetSplitOnWord(v bool) {
p.split_on_word = toBool(v)
}
func (p *Params) SetNoContext(v bool) {
p.no_context = toBool(v)
}

View File

@ -81,10 +81,6 @@ func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
func (context *context) SetSplitOnWord(v bool) {
context.params.SetSplitOnWord(v)
}
// Set number of threads to use
func (context *context) SetThreads(v uint) {
context.params.SetThreads(int(v))
@ -97,7 +93,7 @@ func (context *context) SetOffset(v time.Duration) {
// Set duration of audio to process
func (context *context) SetDuration(v time.Duration) {
context.params.SetDuration(int(v.Milliseconds()))
context.params.SetOffset(int(v.Milliseconds()))
}
// Set timestamp token probability threshold (~0.01)
@ -156,16 +152,12 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
}
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
func (context *context) Process(data []float32, cb SegmentCallback) error {
if context.model.ctx == nil {
return ErrInternalAppError
}
// If the callback is defined then we force on single_segment mode
if callNewSegment != nil {
if cb != nil {
context.params.SetSingleSegment(true)
}
@ -173,28 +165,24 @@ func (context *context) Process(
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if callNewSegment != nil {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
cb(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
cb(toSegment(context.model.ctx, i))
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
return err
}

View File

@ -12,10 +12,6 @@ import (
// time. It is called during the Process function
type SegmentCallback func(Segment)
// ProgressCallback is the callback function for reporting progress during
// processing. It is called during the Process function
type ProgressCallback func(int)
// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
@ -42,7 +38,6 @@ type Context interface {
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
@ -52,7 +47,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback, ProgressCallback) error
Process([]float32, SegmentCallback) error
// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.

View File

@ -15,7 +15,6 @@ import (
#include <stdlib.h>
extern void callNewSegment(void* user_data, int new);
extern void callProgress(void* user_data, int progress);
extern bool callEncoderBegin(void* user_data);
// Text segment callback
@ -27,15 +26,6 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
}
}
// Progress callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
if(user_data != NULL && ctx != NULL) {
callProgress(user_data, progress);
}
}
// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
@ -53,8 +43,6 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
params.new_segment_callback_user_data = (void*)(ctx);
params.encoder_begin_callback = whisper_encoder_begin_cb;
params.encoder_begin_callback_user_data = (void*)(ctx);
params.progress_callback = whisper_progress_cb;
params.progress_callback_user_data = (void*)(ctx);
return params;
}
*/
@ -270,13 +258,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
}
// Task tokens
func (ctx *Context) Whisper_token_translate() Token {
return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
}
// Task tokens
func (ctx *Context) Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
}
// Performance information
@ -302,19 +290,11 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full(
params Params,
samples []float32,
encoderBeginCallback func() bool,
newSegmentCallback func(int),
progressCallback func(int),
) error {
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
registerProgressCallback(ctx, progressCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
defer registerProgressCallback(ctx, nil)
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
@ -338,18 +318,6 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc
}
}
// Return the id of the autodetected language, returns -1 if not found
// Added to whisper.cpp in
// https://github.com/ggerganov/whisper.cpp/commit/a1c1583cc7cd8b75222857afc936f0638c5683d6
//
// Examples:
//
// "de" -> 2
// "german" -> 2
func (ctx *Context) Whisper_full_lang_id() int {
return int(C.whisper_full_lang_id((*C.struct_whisper_context)(ctx)))
}
// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
@ -402,7 +370,6 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
var (
cbNewSegment = make(map[unsafe.Pointer]func(int))
cbProgress = make(map[unsafe.Pointer]func(int))
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
)
@ -414,14 +381,6 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
}
}
func registerProgressCallback(ctx *Context, fn func(int)) {
if fn == nil {
delete(cbProgress, unsafe.Pointer(ctx))
} else {
cbProgress[unsafe.Pointer(ctx)] = fn
}
}
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
if fn == nil {
delete(cbEncoderBegin, unsafe.Pointer(ctx))
@ -437,13 +396,6 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
}
}
//export callProgress
func callProgress(user_data unsafe.Pointer, progress C.int) {
if fn, ok := cbProgress[user_data]; ok {
fn(int(progress))
}
}
//export callEncoderBegin
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
if fn, ok := cbEncoderBegin[user_data]; ok {
@ -463,7 +415,3 @@ func (t TokenData) T0() int64 {
func (t TokenData) T1() int64 {
return int64(t.t1)
}
func (t TokenData) Id() Token {
return Token(t.id)
}

View File

@ -52,7 +52,7 @@ func Test_Whisper_001(t *testing.T) {
defer ctx.Whisper_free()
params := ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
data := buf.AsFloat32Buffer().Data
err = ctx.Whisper_full(params, data, nil, nil, nil)
err = ctx.Whisper_full(params, data, nil, nil)
assert.NoError(err)
// Print out tokens

View File

@ -1,124 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Palette2">
<group name="Swing">
<item class="com.intellij.uiDesigner.HSpacer" tooltip-text="Horizontal Spacer" icon="/com/intellij/uiDesigner/icons/hspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="1" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="com.intellij.uiDesigner.VSpacer" tooltip-text="Vertical Spacer" icon="/com/intellij/uiDesigner/icons/vspacer.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="1" anchor="0" fill="2" />
</item>
<item class="javax.swing.JPanel" icon="/com/intellij/uiDesigner/icons/panel.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3" />
</item>
<item class="javax.swing.JScrollPane" icon="/com/intellij/uiDesigner/icons/scrollPane.svg" removable="false" auto-create-binding="false" can-attach-label="true">
<default-constraints vsize-policy="7" hsize-policy="7" anchor="0" fill="3" />
</item>
<item class="javax.swing.JButton" icon="/com/intellij/uiDesigner/icons/button.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="0" fill="1" />
<initial-values>
<property name="text" value="Button" />
</initial-values>
</item>
<item class="javax.swing.JRadioButton" icon="/com/intellij/uiDesigner/icons/radioButton.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="RadioButton" />
</initial-values>
</item>
<item class="javax.swing.JCheckBox" icon="/com/intellij/uiDesigner/icons/checkBox.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="3" anchor="8" fill="0" />
<initial-values>
<property name="text" value="CheckBox" />
</initial-values>
</item>
<item class="javax.swing.JLabel" icon="/com/intellij/uiDesigner/icons/label.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="8" fill="0" />
<initial-values>
<property name="text" value="Label" />
</initial-values>
</item>
<item class="javax.swing.JTextField" icon="/com/intellij/uiDesigner/icons/textField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JPasswordField" icon="/com/intellij/uiDesigner/icons/passwordField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JFormattedTextField" icon="/com/intellij/uiDesigner/icons/formattedTextField.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1">
<preferred-size width="150" height="-1" />
</default-constraints>
</item>
<item class="javax.swing.JTextArea" icon="/com/intellij/uiDesigner/icons/textArea.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTextPane" icon="/com/intellij/uiDesigner/icons/textPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JEditorPane" icon="/com/intellij/uiDesigner/icons/editorPane.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JComboBox" icon="/com/intellij/uiDesigner/icons/comboBox.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="2" anchor="8" fill="1" />
</item>
<item class="javax.swing.JTable" icon="/com/intellij/uiDesigner/icons/table.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JList" icon="/com/intellij/uiDesigner/icons/list.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="2" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTree" icon="/com/intellij/uiDesigner/icons/tree.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3">
<preferred-size width="150" height="50" />
</default-constraints>
</item>
<item class="javax.swing.JTabbedPane" icon="/com/intellij/uiDesigner/icons/tabbedPane.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSplitPane" icon="/com/intellij/uiDesigner/icons/splitPane.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="3" hsize-policy="3" anchor="0" fill="3">
<preferred-size width="200" height="200" />
</default-constraints>
</item>
<item class="javax.swing.JSpinner" icon="/com/intellij/uiDesigner/icons/spinner.svg" removable="false" auto-create-binding="true" can-attach-label="true">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSlider" icon="/com/intellij/uiDesigner/icons/slider.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="8" fill="1" />
</item>
<item class="javax.swing.JSeparator" icon="/com/intellij/uiDesigner/icons/separator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="6" anchor="0" fill="3" />
</item>
<item class="javax.swing.JProgressBar" icon="/com/intellij/uiDesigner/icons/progressbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1" />
</item>
<item class="javax.swing.JToolBar" icon="/com/intellij/uiDesigner/icons/toolbar.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="6" anchor="0" fill="1">
<preferred-size width="-1" height="20" />
</default-constraints>
</item>
<item class="javax.swing.JToolBar$Separator" icon="/com/intellij/uiDesigner/icons/toolbarSeparator.svg" removable="false" auto-create-binding="false" can-attach-label="false">
<default-constraints vsize-policy="0" hsize-policy="0" anchor="0" fill="1" />
</item>
<item class="javax.swing.JScrollBar" icon="/com/intellij/uiDesigner/icons/scrollbar.svg" removable="false" auto-create-binding="true" can-attach-label="false">
<default-constraints vsize-policy="6" hsize-policy="0" anchor="0" fill="2" />
</item>
</group>
</component>
</project>

View File

@ -1,71 +0,0 @@
# Java JNI bindings for Whisper
This package provides Java JNI bindings for whisper.cpp. They have been tested on:
* <strike>Darwin (OS X) 12.6 on x64_64</strike>
* Ubuntu on x86_64
* Windows on x86_64
The "low level" bindings are in `WhisperCppJnaLibrary`. The most simple usage is as follows:
JNA will attempt to load the `whispercpp` shared library from:
- jna.library.path
- jna.platform.library
- ~/Library/Frameworks
- /Library/Frameworks
- /System/Library/Frameworks
- classpath
```java
import io.github.ggerganov.whispercpp.WhisperCpp;
public class Example {
public static void main(String[] args) {
WhisperCpp whisper = new WhisperCpp();
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
long context = whisper.initContext("base.en");
try {
var whisperParams = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// custom configuration if required
whisperParams.temperature_inc = 0f;
var samples = readAudio(); // divide each value by 32767.0f
whisper.fullTranscribe(whisperParams, samples);
int segmentCount = whisper.getTextSegmentCount(context);
for (int i = 0; i < segmentCount; i++) {
String text = whisper.getTextSegment(context, i);
System.out.println(segment.getText());
}
} finally {
whisper.freeContext(context);
}
}
}
```
## Building & Testing
In order to build, you need to have the JDK 8 or higher installed. Run the tests with:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp/bindings/java
./gradlew build
```
You need to have the `whisper` library in your [JNA library path](https://java-native-access.github.io/jna/4.2.1/com/sun/jna/NativeLibrary.html). On Windows the dll is included in the jar and you can update it:
```bash
copy /y ..\..\build\bin\Release\whisper.dll build\generated\resources\main\win32-x86-64\whisper.dll
```
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -1,112 +0,0 @@
plugins {
id 'java'
id 'java-library'
id 'maven-publish'
}
archivesBaseName = 'whispercpp'
group = 'io.github.ggerganov'
version = '1.4.0'
sourceCompatibility = 1.8
targetCompatibility = 1.8
sourceSets {
main {
resources {
srcDirs = ['src/main/resources', 'build/generated/resources/main']
}
}
test {
runtimeClasspath += files('build/generated/resources/main')
}
}
tasks.register('copyLibwhisperDynlib', Copy) {
from '../../build'
include 'libwhisper.dynlib'
into 'build/generated/resources/main/darwin'
}
tasks.register('copyLibwhisperSo', Copy) {
from '../../build'
include 'libwhisper.so'
into 'build/generated/resources/main/linux-x86-64'
}
tasks.register('copyWhisperDll', Copy) {
from '../../build/Release'
include 'whisper.dll'
into 'build/generated/resources/main/windows-x86-64'
}
tasks.register('copyLibs') {
dependsOn copyLibwhisperDynlib, copyLibwhisperSo, copyWhisperDll
}
test {
systemProperty 'jna.library.path', project.file('build/generated/resources/main').absolutePath
}
java {
withSourcesJar()
withJavadocJar()
}
jar {
exclude '**/whisper_java.exp', '**/whisper_java.lib'
}
javadoc {
options.addStringOption('Xdoclint:none', '-quiet')
}
tasks.withType(Test) {
useJUnitPlatform()
}
dependencies {
implementation "net.java.dev.jna:jna:5.13.0"
testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"
testImplementation "org.assertj:assertj-core:3.24.2"
}
repositories {
mavenCentral()
}
publishing {
publications {
mavenJava(MavenPublication) {
artifactId = 'whispercpp'
from components.java
pom {
name = 'whispercpp'
description = "Java JNA bindings for OpenAI's Whisper model, implemented in C/C++"
url = 'https://github.com/ggerganov/whisper.cpp'
licenses {
license {
name = 'MIT licence'
url = 'https://raw.githubusercontent.com/ggerganov/whisper.cpp/master/LICENSE'
}
}
developers {
developer {
id = 'ggerganov'
name = 'Georgi Gerganov'
email = 'ggerganov@gmail.com'
}
developer {
id = 'nalbion'
name = 'Nicholas Albion'
email = 'nalbion@yahoo.com'
}
}
scm {
connection = 'scm:git:git://github.com/ggerganov/whisper.cpp.git'
url = 'https://github.com/ggerganov/whisper.cpp'
}
}
}
}
}

View File

@ -1,6 +0,0 @@
org.gradle.jvmargs=-Xms256m -Xmx1024m
system.include.dir=/usr/include
#system.local.include.dir=../../include
system.local.include.dir=./build/generated/sources/headers/java/main
jni.include.dir=/usr/lib/jvm/java-8-openjdk-amd64/include/
jni.lib.dir=/usr/lib/jvm/java-8-openjdk-amd64/lib/

Binary file not shown.

View File

@ -1,6 +0,0 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

244
bindings/java/gradlew vendored
View File

@ -1,244 +0,0 @@
#!/bin/sh
#
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
#
# Gradle start up script for POSIX generated by Gradle.
#
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
warn () {
echo "$*"
} >&2
die () {
echo
echo "$*"
echo
exit 1
} >&2
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Collect all arguments for the java command, stacking in reverse order:
# * args from the command line
# * the main class name
# * -classpath
# * -D...appname settings
# * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

View File

@ -1,92 +0,0 @@
@rem
@rem Copyright 2015 the original author or authors.
@rem
@rem Licensed under the Apache License, Version 2.0 (the "License");
@rem you may not use this file except in compliance with the License.
@rem You may obtain a copy of the License at
@rem
@rem https://www.apache.org/licenses/LICENSE-2.0
@rem
@rem Unless required by applicable law or agreed to in writing, software
@rem distributed under the License is distributed on an "AS IS" BASIS,
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@rem See the License for the specific language governing permissions and
@rem limitations under the License.
@rem
@if "%DEBUG%"=="" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
:end
@rem End local scope for the variables with windows NT shell
if %ERRORLEVEL% equ 0 goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega

View File

@ -1 +0,0 @@
rootProject.name = "whispercpp"

View File

@ -1,39 +0,0 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Structure;
import com.sun.jna.ptr.PointerByReference;
import io.github.ggerganov.whispercpp.ggml.GgmlType;
import io.github.ggerganov.whispercpp.WhisperModel;
import java.util.List;
public class WhisperContext extends Structure {
int t_load_us = 0;
int t_start_us = 0;
/** weight type (FP32 / FP16 / QX) */
GgmlType wtype = GgmlType.GGML_TYPE_F16;
/** intermediate type (FP32 or FP16) */
GgmlType itype = GgmlType.GGML_TYPE_F16;
// WhisperModel model;
public PointerByReference model;
// whisper_vocab vocab;
// whisper_state * state = nullptr;
public PointerByReference vocab;
public PointerByReference state;
/** populated by whisper_init_from_file() */
String path_model;
// public static class ByReference extends WhisperContext implements Structure.ByReference {
// }
//
// public static class ByValue extends WhisperContext implements Structure.ByValue {
// }
//
// @Override
// protected List<String> getFieldOrder() {
// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
// }
}

View File

@ -1,151 +0,0 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
/**
* Before calling most methods, you must call `initContext(modelPath)` to initialise the `ctx` Pointer.
*/
public class WhisperCpp implements AutoCloseable {
private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance;
private Pointer ctx = null;
private Pointer greedyPointer = null;
private Pointer beamPointer = null;
public File modelDir() {
String modelDirPath = System.getenv("XDG_CACHE_HOME");
if (modelDirPath == null) {
modelDirPath = System.getProperty("user.home") + "/.cache";
}
return new File(modelDirPath, "whisper");
}
/**
* @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
*/
public void initContext(String modelPath) throws FileNotFoundException {
if (ctx != null) {
lib.whisper_free(ctx);
}
if (!modelPath.contains("/") && !modelPath.contains("\\")) {
if (!modelPath.endsWith(".bin")) {
modelPath = "ggml-" + modelPath.replace("-", ".") + ".bin";
}
modelPath = new File(modelDir(), modelPath).getAbsolutePath();
}
ctx = lib.whisper_init_from_file(modelPath);
if (ctx == null) {
throw new FileNotFoundException(modelPath);
}
}
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - GREEDY
*/
public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
Pointer pointer;
// whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) {
if (greedyPointer == null) {
greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = greedyPointer;
} else {
if (beamPointer == null) {
beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal());
}
pointer = beamPointer;
}
WhisperFullParams params = new WhisperFullParams(pointer);
params.read();
return params;
}
@Override
public void close() {
freeContext();
freeParams();
System.out.println("Whisper closed");
}
private void freeContext() {
if (ctx != null) {
lib.whisper_free(ctx);
}
}
private void freeParams() {
if (greedyPointer != null) {
Native.free(Pointer.nativeValue(greedyPointer));
greedyPointer = null;
}
if (beamPointer != null) {
Native.free(Pointer.nativeValue(beamPointer));
beamPointer = null;
}
}
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
if (ctx == null) {
throw new IllegalStateException("Model not initialised");
}
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
throw new IOException("Failed to process audio");
}
int nSegments = lib.whisper_full_n_segments(ctx);
StringBuilder str = new StringBuilder();
for (int i = 0; i < nSegments; i++) {
String text = lib.whisper_full_get_segment_text(ctx, i);
System.out.println("Segment:" + text);
str.append(text);
}
return str.toString().trim();
}
// public int getTextSegmentCount(Pointer ctx) {
// return lib.whisper_full_n_segments(ctx);
// }
// public String getTextSegment(Pointer ctx, int index) {
// return lib.whisper_full_get_segment_text(ctx, index);
// }
public String getSystemInfo() {
return lib.whisper_print_system_info();
}
public int benchMemcpy(int nthread) {
return lib.whisper_bench_memcpy(nthread);
}
public int benchGgmlMulMat(int nthread) {
return lib.whisper_bench_ggml_mul_mat(nthread);
}
}

View File

@ -1,376 +0,0 @@
package io.github.ggerganov.whispercpp;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperModelLoader;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
public interface WhisperCppJnaLibrary extends Library {
WhisperCppJnaLibrary instance = Native.load("whisper", WhisperCppJnaLibrary.class);
String whisper_print_system_info();
/**
* Allocate (almost) all memory needed for the model by loading from a file.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer(Pointer buffer, int buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init(WhisperModelLoader loader);
/**
* Allocate (almost) all memory needed for the model by loading from a file without allocating the state.
*
* @param path_model Path to the model file
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_file_no_state(String path_model);
/**
* Allocate (almost) all memory needed for the model by loading from a buffer without allocating the state.
*
* @param buffer Model buffer
* @param buffer_size Size of the model buffer
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_from_buffer_no_state(Pointer buffer, int buffer_size);
// Pointer whisper_init_from_buffer_no_state(Pointer buffer, long buffer_size);
/**
* Allocate (almost) all memory needed for the model using a model loader without allocating the state.
*
* @param loader Model loader
* @return Whisper context on success, null on failure
*/
Pointer whisper_init_no_state(WhisperModelLoader loader);
/**
* Allocate memory for the Whisper state.
*
* @param ctx Whisper context
* @return Whisper state on success, null on failure
*/
Pointer whisper_init_state(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper context.
*
* @param ctx Whisper context
*/
void whisper_free(Pointer ctx);
/**
* Free all allocated memory associated with the Whisper state.
*
* @param state Whisper state
*/
void whisper_free_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
*
* @param ctx - Pointer to a WhisperContext
* @return 0 on success
*/
int whisper_pcm_to_mel(Pointer ctx, final float[] samples, int n_samples, int n_threads);
/**
* @param ctx Pointer to a WhisperContext
* @param state Pointer to WhisperState
* @param n_samples
* @param n_threads
* @return 0 on success
*/
int whisper_pcm_to_mel_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/**
* This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
* Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
* n_mel must be 80
* @return 0 on success
*/
int whisper_set_mel(Pointer ctx, final float[] data, int n_len, int n_mel);
int whisper_set_mel_with_state(Pointer ctx, Pointer state, final float[] data, int n_len, int n_mel);
/**
* Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
* Offset can be used to specify the offset of the first frame in the spectrogram.
* @return 0 on success
*/
int whisper_encode(Pointer ctx, int offset, int n_threads);
int whisper_encode_with_state(Pointer ctx, Pointer state, int offset, int n_threads);
/**
* Run the Whisper decoder to obtain the logits and probabilities for the next token.
* Make sure to call whisper_encode() first.
* tokens + n_tokens is the provided context for the decoder.
* n_past is the number of tokens to use from previous decoder calls.
* Returns 0 on success
* TODO: add support for multiple decoders
*/
int whisper_decode(Pointer ctx, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* @param ctx
* @param state
* @param tokens Pointer to int tokens
* @param n_tokens
* @param n_past
* @param n_threads
* @return
*/
int whisper_decode_with_state(Pointer ctx, Pointer state, Pointer tokens, int n_tokens, int n_past, int n_threads);
/**
* Convert the provided text into tokens.
* The tokens pointer must be large enough to hold the resulting tokens.
* Returns the number of tokens on success, no more than n_max_tokens
* Returns -1 on failure
* TODO: not sure if correct
*/
int whisper_tokenize(Pointer ctx, String text, Pointer tokens, int n_max_tokens);
/** Largest language id (i.e. number of available languages - 1) */
int whisper_lang_max_id();
/**
* @return the id of the specified language, returns -1 if not found.
* Examples:
* "de" -> 2
* "german" -> 2
*/
int whisper_lang_id(String lang);
/** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
String whisper_lang_str(int id);
/**
* Use mel data at offset_ms to try and auto-detect the spoken language.
* Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first
* Returns the top language id or negative on failure
* If not null, fills the lang_probs array with the probabilities of all languages
* The array must be whisper_lang_max_id() + 1 in size
*
* ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
*/
int whisper_lang_auto_detect(Pointer ctx, int offset_ms, int n_threads, float[] lang_probs);
int whisper_lang_auto_detect_with_state(Pointer ctx, Pointer state, int offset_ms, int n_threads, float[] lang_probs);
int whisper_n_len (Pointer ctx); // mel length
int whisper_n_len_from_state(Pointer state); // mel length
int whisper_n_vocab (Pointer ctx);
int whisper_n_text_ctx (Pointer ctx);
int whisper_n_audio_ctx (Pointer ctx);
int whisper_is_multilingual (Pointer ctx);
int whisper_model_n_vocab (Pointer ctx);
int whisper_model_n_audio_ctx (Pointer ctx);
int whisper_model_n_audio_state(Pointer ctx);
int whisper_model_n_audio_head (Pointer ctx);
int whisper_model_n_audio_layer(Pointer ctx);
int whisper_model_n_text_ctx (Pointer ctx);
int whisper_model_n_text_state (Pointer ctx);
int whisper_model_n_text_head (Pointer ctx);
int whisper_model_n_text_layer (Pointer ctx);
int whisper_model_n_mels (Pointer ctx);
int whisper_model_ftype (Pointer ctx);
int whisper_model_type (Pointer ctx);
/**
* Token logits obtained from the last call to whisper_decode().
* The logits for the last token are stored in the last row
* Rows: n_tokens
* Cols: n_vocab
*/
float[] whisper_get_logits (Pointer ctx);
float[] whisper_get_logits_from_state(Pointer state);
// Token Id -> String. Uses the vocabulary in the provided context
String whisper_token_to_str(Pointer ctx, int token);
String whisper_model_type_readable(Pointer ctx);
// Special tokens
int whisper_token_eot (Pointer ctx);
int whisper_token_sot (Pointer ctx);
int whisper_token_prev(Pointer ctx);
int whisper_token_solm(Pointer ctx);
int whisper_token_not (Pointer ctx);
int whisper_token_beg (Pointer ctx);
int whisper_token_lang(Pointer ctx, int lang_id);
// Task tokens
int whisper_token_translate (Pointer ctx);
int whisper_token_transcribe(Pointer ctx);
// Performance information from the default state.
void whisper_print_timings(Pointer ctx);
void whisper_reset_timings(Pointer ctx);
// Note: Even if `whisper_full_params is stripped back to just 4 ints, JNA throws "Invalid memory access"
// when `whisper_full_default_params()` tries to return a struct.
// WhisperFullParams whisper_full_default_params(int strategy);
/**
* Provides default params which can be used with `whisper_full()` etc.
* Because this function allocates memory for the params, the caller must call either:
* - call `whisper_free_params()`
* - `Native.free(Pointer.nativeValue(pointer));`
*
* @param strategy - WhisperSamplingStrategy.value
*/
Pointer whisper_full_default_params_by_ref(int strategy);
void whisper_free_params(Pointer params);
/**
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*/
int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
// Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
// Result is stored in the default state of the context
// Not thread safe if executed in parallel on the same context.
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
/**
* Number of generated text segments.
* A segment can be a few words, a sentence, or even a paragraph.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_n_segments (Pointer ctx);
/**
* @param state Pointer to WhisperState
*/
int whisper_full_n_segments_from_state(Pointer state);
/**
* Language id associated with the context's default state.
* @param ctx Pointer to WhisperContext
*/
int whisper_full_lang_id(Pointer ctx);
/** Language id associated with the provided state */
int whisper_full_lang_id_from_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
* @return 0 on success
*/
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/** Get the start time of the specified segment. */
long whisper_full_get_segment_t0(Pointer ctx, int i_segment);
/** Get the start time of the specified segment from the state. */
long whisper_full_get_segment_t0_from_state(Pointer state, int i_segment);
/** Get the end time of the specified segment. */
long whisper_full_get_segment_t1(Pointer ctx, int i_segment);
/** Get the end time of the specified segment from the state. */
long whisper_full_get_segment_t1_from_state(Pointer state, int i_segment);
/** Get the text of the specified segment. */
String whisper_full_get_segment_text(Pointer ctx, int i_segment);
/** Get the text of the specified segment from the state. */
String whisper_full_get_segment_text_from_state(Pointer state, int i_segment);
/** Get the number of tokens in the specified segment. */
int whisper_full_n_tokens(Pointer ctx, int i_segment);
/** Get the number of tokens in the specified segment from the state. */
int whisper_full_n_tokens_from_state(Pointer state, int i_segment);
/** Get the token text of the specified token in the specified segment. */
String whisper_full_get_token_text(Pointer ctx, int i_segment, int i_token);
/** Get the token text of the specified token in the specified segment from the state. */
String whisper_full_get_token_text_from_state(Pointer ctx, Pointer state, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment. */
int whisper_full_get_token_id(Pointer ctx, int i_segment, int i_token);
/** Get the token ID of the specified token in the specified segment from the state. */
int whisper_full_get_token_id_from_state(Pointer state, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment. */
WhisperTokenData whisper_full_get_token_data(Pointer ctx, int i_segment, int i_token);
/** Get token data for the specified token in the specified segment from the state. */
WhisperTokenData whisper_full_get_token_data_from_state(Pointer state, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment. */
float whisper_full_get_token_p(Pointer ctx, int i_segment, int i_token);
/** Get the probability of the specified token in the specified segment from the state. */
float whisper_full_get_token_p_from_state(Pointer state, int i_segment, int i_token);
/**
* Benchmark function for memcpy.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_memcpy(int nThreads);
/**
* Benchmark function for memcpy as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_memcpy_str(int nThreads);
/**
* Benchmark function for ggml_mul_mat.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark.
*/
int whisper_bench_ggml_mul_mat(int nThreads);
/**
* Benchmark function for ggml_mul_mat as a string.
*
* @param nThreads Number of threads to use for the benchmark.
* @return The result of the benchmark as a string.
*/
String whisper_bench_ggml_mul_mat_str(int nThreads);
}

View File

@ -1,24 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback before the encoder starts.
* If not null, called before the encoder starts.
* If it returns false, the computation is aborted.
*/
public interface WhisperEncoderBeginCallback extends Callback {
/**
* Callback method before the encoder starts.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param user_data User data.
* @return True if the computation should proceed, false otherwise.
*/
boolean callback(Pointer ctx, Pointer state, Pointer user_data);
}

View File

@ -1,25 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.model.WhisperTokenData;
/**
* Callback to filter logits.
* Can be used to modify the logits before sampling.
* If not null, called after applying temperature to logits.
*/
public interface WhisperLogitsFilterCallback extends Callback {
/**
* Callback method to filter logits.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param tokens The array of whisper_token_data.
* @param n_tokens The number of tokens.
* @param logits The array of logits.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, WhisperTokenData[] tokens, int n_tokens, float[] logits, Pointer user_data);
}

View File

@ -1,24 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for the text segment.
* Called on every newly generated text segment.
* Use the whisper_full_...() functions to obtain the text segments.
*/
public interface WhisperNewSegmentCallback extends Callback {
/**
* Callback method for the text segment.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param n_new The number of newly generated text segments.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int n_new, Pointer user_data);
}

View File

@ -1,22 +0,0 @@
package io.github.ggerganov.whispercpp.callbacks;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.WhisperContext;
import io.github.ggerganov.whispercpp.model.WhisperState;
/**
* Callback for progress updates.
*/
public interface WhisperProgressCallback extends Callback {
/**
* Callback method for progress updates.
*
* @param ctx The whisper context.
* @param state The whisper state.
* @param progress The progress value.
* @param user_data User data.
*/
void callback(Pointer ctx, Pointer state, int progress, Pointer user_data);
}

View File

@ -1,4 +0,0 @@
package io.github.ggerganov.whispercpp.ggml;
public class GgmlTensor {
}

View File

@ -1,18 +0,0 @@
package io.github.ggerganov.whispercpp.ggml;
public enum GgmlType {
GGML_TYPE_F32,
GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
REMOVED_GGML_TYPE_Q4_2, // support has been removed
REMOVED_GGML_TYPE_Q4_3, // support has been removed
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q8_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
GGML_TYPE_COUNT,
}

View File

@ -1,10 +0,0 @@
package io.github.ggerganov.whispercpp.model;
public enum EModel {
MODEL_UNKNOWN,
MODEL_TINY,
MODEL_BASE,
MODEL_SMALL,
MODEL_MEDIUM,
MODEL_LARGE,
}

View File

@ -1,49 +0,0 @@
package io.github.ggerganov.whispercpp;
import io.github.ggerganov.whispercpp.ggml.GgmlTensor;
import io.github.ggerganov.whispercpp.model.EModel;
public class WhisperModel {
// EModel type = EModel.MODEL_UNKNOWN;
//
// WhisperHParams hparams;
// WhisperFilters filters;
//
// // encoder.positional_embedding
// GgmlTensor e_pe;
//
// // encoder.conv1
// GgmlTensor e_conv_1_w;
// GgmlTensor e_conv_1_b;
//
// // encoder.conv2
// GgmlTensor e_conv_2_w;
// GgmlTensor e_conv_2_b;
//
// // encoder.ln_post
// GgmlTensor e_ln_w;
// GgmlTensor e_ln_b;
//
// // decoder.positional_embedding
// GgmlTensor d_pe;
//
// // decoder.token_embedding
// GgmlTensor d_te;
//
// // decoder.ln
// GgmlTensor d_ln_w;
// GgmlTensor d_ln_b;
//
// std::vector<whisper_layer_encoder> layers_encoder;
// std::vector<whisper_layer_decoder> layers_decoder;
//
// // context
// struct ggml_context * ctx;
//
// // the model memory buffer is read-only and can be shared between processors
// std::vector<uint8_t> * buf;
//
// // tensors
// int n_loaded;
// Map<String, GgmlTensor> tensors;
}

View File

@ -1,62 +0,0 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Callback;
import com.sun.jna.Pointer;
import com.sun.jna.Structure;
public class WhisperModelLoader extends Structure {
public Pointer context;
public ReadFunction read;
public EOFFunction eof;
public CloseFunction close;
public static class ReadFunction implements Callback {
public Pointer invoke(Pointer ctx, Pointer output, int readSize) {
// TODO
return ctx;
}
}
public static class EOFFunction implements Callback {
public boolean invoke(Pointer ctx) {
// TODO
return false;
}
}
public static class CloseFunction implements Callback {
public void invoke(Pointer ctx) {
// TODO
}
}
// public WhisperModelLoader(Pointer p) {
// super(p);
// read = new ReadFunction();
// eof = new EOFFunction();
// close = new CloseFunction();
// read.setCallback(this);
// eof.setCallback(this);
// close.setCallback(this);
// read.write();
// eof.write();
// close.write();
// }
public WhisperModelLoader() {
super();
}
public interface ReadCallback extends Callback {
Pointer invoke(Pointer ctx, Pointer output, int readSize);
}
public interface EOFCallback extends Callback {
boolean invoke(Pointer ctx);
}
public interface CloseCallback extends Callback {
void invoke(Pointer ctx);
}
}

View File

@ -1,4 +0,0 @@
package io.github.ggerganov.whispercpp.model;
public class WhisperState {
}

View File

@ -1,50 +0,0 @@
package io.github.ggerganov.whispercpp.model;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
/**
* Structure representing token data.
*/
public class WhisperTokenData extends Structure {
/** Token ID. */
public int id;
/** Forced timestamp token ID. */
public int tid;
/** Probability of the token. */
public float p;
/** Log probability of the token. */
public float plog;
/** Probability of the timestamp token. */
public float pt;
/** Sum of probabilities of all timestamp tokens. */
public float ptsum;
/**
* Start time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t0;
/**
* End time of the token (token-level timestamp data).
* Do not use if you haven't computed token-level timestamps.
*/
public long t1;
/** Voice length of the token. */
public float vlen;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("id", "tid", "p", "plog", "pt", "ptsum", "t0", "t1", "vlen");
}
}

View File

@ -1,19 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Arrays;
import java.util.List;
public class BeamSearchParams extends Structure {
/** ref: <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265">...</a> */
public int beam_size;
/** ref: <a href="https://arxiv.org/pdf/2204.05424.pdf">...</a> */
public float patience;
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("beam_size", "patience");
}
}

View File

@ -1,30 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.IntegerType;
import java.util.function.BooleanSupplier;
public class CBool extends IntegerType implements BooleanSupplier {
public static final int SIZE = 1;
public static final CBool FALSE = new CBool(0);
public static final CBool TRUE = new CBool(1);
public CBool() {
this(0);
}
public CBool(long value) {
super(SIZE, value, true);
}
@Override
public boolean getAsBoolean() {
return intValue() == 1;
}
@Override
public String toString() {
return intValue() == 1 ? "true" : "false";
}
}

View File

@ -1,16 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.Structure;
import java.util.Collections;
import java.util.List;
public class GreedyParams extends Structure {
/** <a href="https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264">...</a> */
public int best_of;
@Override
protected List<String> getFieldOrder() {
return Collections.singletonList("best_of");
}
}

View File

@ -1,10 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import java.util.List;
public class WhisperFilters {
int n_mel;
int n_fft;
List<Float> data;
}

View File

@ -1,321 +0,0 @@
package io.github.ggerganov.whispercpp.params;
import com.sun.jna.*;
import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
import java.util.Arrays;
import java.util.List;
/**
* Parameters for the whisper_full() function.
* If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
* whisper_full_default_params()
*/
public class WhisperFullParams extends Structure {
public WhisperFullParams(Pointer p) {
super(p);
// super(p, ALIGN_MSVC);
// super(p, ALIGN_GNUC);
}
/** Sampling strategy for whisper_full() function. */
public int strategy;
/** Number of threads. (default = 4) */
public int n_threads;
/** Maximum tokens to use from past text as a prompt for the decoder. (default = 16384) */
public int n_max_text_ctx;
/** Start offset in milliseconds. (default = 0) */
public int offset_ms;
/** Audio duration to process in milliseconds. (default = 0) */
public int duration_ms;
/** Translate flag. (default = false) */
public CBool translate;
/** The compliment of translateMode() */
public void transcribeMode() {
translate = CBool.FALSE;
}
/** The compliment of transcribeMode() */
public void translateMode() {
translate = CBool.TRUE;
}
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public CBool no_context;
/** Flag to indicate whether to use past transcription (if any) as an initial prompt for the decoder. (default = true) */
public void enableContext(boolean enable) {
no_context = enable ? CBool.FALSE : CBool.TRUE;
}
/** Flag to force single segment output (useful for streaming). (default = false) */
public CBool single_segment;
/** Flag to force single segment output (useful for streaming). (default = false) */
public void singleSegment(boolean single) {
single_segment = single ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public CBool print_special;
/** Flag to print special tokens (e.g., &lt;SOT>, &lt;EOT>, &lt;BEG>, etc.). (default = false) */
public void printSpecial(boolean enable) {
print_special = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print progress information. (default = true) */
public CBool print_progress;
/** Flag to print progress information. (default = true) */
public void printProgress(boolean enable) {
print_progress = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public CBool print_realtime;
/** Flag to print results from within whisper.cpp (avoid it, use callback instead). (default = true) */
public void printRealtime(boolean enable) {
print_realtime = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public CBool print_timestamps;
/** Flag to print timestamps for each text segment when printing realtime. (default = true) */
public void printTimestamps(boolean enable) {
print_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public CBool token_timestamps;
/** [EXPERIMENTAL] Flag to enable token-level timestamps. (default = false) */
public void tokenTimestamps(boolean enable) {
token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
}
/** [EXPERIMENTAL] Timestamp token probability threshold (~0.01). (default = 0.01) */
public float thold_pt;
/** [EXPERIMENTAL] Timestamp token sum probability threshold (~0.01). */
public float thold_ptsum;
/** Maximum segment length in characters. (default = 0) */
public int max_len;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public CBool split_on_word;
/** Flag to split on word rather than on token (when used with max_len). (default = false) */
public void splitOnWord(boolean enable) {
split_on_word = enable ? CBool.TRUE : CBool.FALSE;
}
/** Maximum tokens per segment (0, default = no limit) */
public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */
public int audio_ctx;
/** Enable tinydiarize (default = false) */
public CBool tdrz_enable;
/** Enable tinydiarize (default = false) */
public void tdrzEnable(boolean enable) {
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
}
/** Tokens to provide to the whisper decoder as an initial prompt.
* These are prepended to any existing text context from a previous call. */
public String initial_prompt;
/** Prompt tokens. (int*) */
public Pointer prompt_tokens;
public void setPromptTokens(int[] tokens) {
Memory mem = new Memory(tokens.length * 4L);
mem.write(0, tokens, 0, tokens.length);
prompt_tokens = mem;
}
/** Number of prompt tokens. */
public int prompt_n_tokens;
/** Language for auto-detection.
* For auto-detection, set to `null`, `""`, or "auto". */
public String language;
/** Flag to indicate whether to detect language automatically. */
public CBool detect_language;
/** Flag to indicate whether to detect language automatically. */
public void detectLanguage(boolean enable) {
detect_language = enable ? CBool.TRUE : CBool.FALSE;
}
// Common decoding parameters.
/** Flag to suppress blank tokens. */
public CBool suppress_blank;
public void suppressBlanks(boolean enable) {
suppress_blank = enable ? CBool.TRUE : CBool.FALSE;
}
/** Flag to suppress non-speech tokens. */
public CBool suppress_non_speech_tokens;
/** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE;
}
/** Initial decoding temperature. */
public float temperature;
/** Maximum initial timestamp. */
public float max_initial_ts;
/** Length penalty. */
public float length_penalty;
// Fallback parameters.
/** Temperature increment. */
public float temperature_inc;
/** Entropy threshold (similar to OpenAI's "compression_ratio_threshold"). */
public float entropy_thold;
/** Log probability threshold. */
public float logprob_thold;
/** No speech threshold. */
public float no_speech_thold;
/** Greedy decoding parameters. */
public GreedyParams greedy;
/**
* Beam search decoding parameters.
*/
public BeamSearchParams beam_search;
public void setBestOf(int bestOf) {
if (greedy == null) {
greedy = new GreedyParams();
}
greedy.best_of = bestOf;
}
public void setBeamSize(int beamSize) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
}
public void setBeamSizeAndPatience(int beamSize, float patience) {
if (beam_search == null) {
beam_search = new BeamSearchParams();
}
beam_search.beam_size = beamSize;
beam_search.patience = patience;
}
/**
* Callback for every newly generated text segment.
* WhisperNewSegmentCallback
*/
public Pointer new_segment_callback;
/**
* User data for the new_segment_callback.
*/
public Pointer new_segment_callback_user_data;
/**
* Callback on each progress update.
* WhisperProgressCallback
*/
public Pointer progress_callback;
/**
* User data for the progress_callback.
*/
public Pointer progress_callback_user_data;
/**
* Callback each time before the encoder starts.
* WhisperEncoderBeginCallback
*/
public Pointer encoder_begin_callback;
/**
* User data for the encoder_begin_callback.
*/
public Pointer encoder_begin_callback_user_data;
/**
* Callback by each decoder to filter obtained logits.
* WhisperLogitsFilterCallback
*/
public Pointer logits_filter_callback;
/**
* User data for the logits_filter_callback.
*/
public Pointer logits_filter_callback_user_data;
public void setNewSegmentCallback(WhisperNewSegmentCallback callback) {
new_segment_callback = CallbackReference.getFunctionPointer(callback);
}
public void setProgressCallback(WhisperProgressCallback callback) {
progress_callback = CallbackReference.getFunctionPointer(callback);
}
public void setEncoderBeginCallbackeginCallbackCallback(WhisperEncoderBeginCallback callback) {
encoder_begin_callback = CallbackReference.getFunctionPointer(callback);
}
public void setLogitsFilterCallback(WhisperLogitsFilterCallback callback) {
logits_filter_callback = CallbackReference.getFunctionPointer(callback);
}
@Override
protected List<String> getFieldOrder() {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data",
"encoder_begin_callback", "encoder_begin_callback_user_data",
"logits_filter_callback", "logits_filter_callback_user_data");
}
}

View File

@ -1,15 +0,0 @@
package io.github.ggerganov.whispercpp.params;
public class WhisperHParams {
int n_vocab = 51864;
int n_audio_ctx = 1500;
int n_audio_state = 384;
int n_audio_head = 6;
int n_audio_layer = 4;
int n_text_ctx = 448;
int n_text_state = 384;
int n_text_head = 6;
int n_text_layer = 4;
int n_mels = 80;
int ftype = 1;
}

View File

@ -1,10 +0,0 @@
package io.github.ggerganov.whispercpp.params;
/** Available sampling strategies */
public enum WhisperSamplingStrategy {
/** similar to OpenAI's GreedyDecoder */
WHISPER_SAMPLING_GREEDY,
/** similar to OpenAI's BeamSearchDecoder */
WHISPER_SAMPLING_BEAM_SEARCH
}

View File

@ -1,102 +0,0 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import io.github.ggerganov.whispercpp.params.CBool;
import io.github.ggerganov.whispercpp.params.WhisperFullParams;
import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.File;
import java.io.FileNotFoundException;
class WhisperCppTest {
private static WhisperCpp whisper = new WhisperCpp();
private static boolean modelInitialised = false;
@BeforeAll
static void init() throws FileNotFoundException {
// By default, models are loaded from ~/.cache/whisper/ and are usually named "ggml-${name}.bin"
// or you can provide the absolute path to the model file.
String modelName = "../../models/ggml-tiny.en.bin";
try {
whisper.initContext(modelName);
// whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// whisper.getJavaDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
modelInitialised = true;
} catch (FileNotFoundException ex) {
System.out.println("Model " + modelName + " not found");
}
}
@Test
void testGetDefaultFullParams_BeamSearch() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(2, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@Test
void testGetDefaultFullParams_Greedy() {
// When
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
// Then
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(2, params.greedy.best_of);
}
@Test
void testFullTranscribe() throws Exception {
if (!modelInitialised) {
System.out.println("Model not initialised, skipping test");
return;
}
// Given
File file = new File(System.getProperty("user.dir"), "../../samples/jfk.wav");
AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(file);
byte[] b = new byte[audioInputStream.available()];
float[] floats = new float[b.length / 2];
// WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
params.print_progress = CBool.FALSE;
// params.initial_prompt = "and so my fellow Americans um, like";
try {
audioInputStream.read(b);
for (int i = 0, j = 0; i < b.length; i += 2, j++) {
int intSample = (int) (b[i + 1]) << 8 | (int) (b[i]) & 0xFF;
floats[j] = intSample / 32767.0f;
}
// When
String result = whisper.fullTranscribe(params, floats);
// Then
System.err.println(result);
assertEquals("And so my fellow Americans ask not what your country can do for you " +
"ask what you can do for your country.",
result.replace(",", ""));
} finally {
audioInputStream.close();
}
}
}

View File

@ -1,17 +0,0 @@
package io.github.ggerganov.whispercpp;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;
class WhisperJnaLibraryTest {
@Test
void testWhisperPrint_system_info() {
String systemInfo = WhisperCppJnaLibrary.instance.whisper_print_system_info();
// eg: "AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0
// | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | COREML = 0 | "
System.out.println("System info: " + systemInfo);
assertTrue(systemInfo.length() > 10);
}
}

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.4.2",
"version": "1.4.1",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

5
contrib/debian/control Normal file
View File

@ -0,0 +1,5 @@
Package: whisper-small-cpp
Architecture: amd64
Maintainer: Alexey Kharlamov <alexey@kharlamov.biz>
Description: Whisper Speech to Text Converter
Depends: libc6 (>= 2.2.1), intel-mkl

View File

@ -1,9 +1,5 @@
#if !__has_feature(objc_arc)
#error This file must be compiled with automatic reference counting enabled (-fobjc-arc)
#endif
#import "whisper-encoder.h"
#import "whisper-encoder-impl.h"
#import "coreml/whisper-encoder.h"
#import "coreml/whisper-encoder-impl.h"
#import <CoreML/CoreML.h>
@ -53,11 +49,17 @@ void whisper_coreml_encode(
error: nil
];
@autoreleasepool {
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
whisper_encoder_implOutput * outCoreML = [(__bridge id) ctx->data predictionFromLogmel_data:inMultiArray error:nil];
memcpy(out, outCoreML.output.dataPointer, outCoreML.output.count * sizeof(float));
}
MLMultiArray * outMA = outCoreML.output;
//NSArray<NSNumber *> * shape = outMA.shape;
//NSArray<NSNumber *> * strides = outMA.strides;
//printf("shape: %ld %ld %ld %ld\n", [shape[0] longValue], [shape[1] longValue], [shape[2] longValue], [shape[3] longValue]);
//printf("strides: %ld %ld %ld %ld\n", [strides[0] longValue], [strides[1] longValue], [strides[2] longValue], [strides[3] longValue]);
memcpy(out, outMA.dataPointer, outMA.count * sizeof(float));
}
#if __cplusplus

View File

@ -69,5 +69,4 @@ else()
add_subdirectory(quantize)
add_subdirectory(talk)
add_subdirectory(talk-llama)
add_subdirectory(lsp)
endif()

View File

@ -6,6 +6,7 @@
static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q4_0", GGML_FTYPE_MOSTLY_Q4_0},
{"q4_1", GGML_FTYPE_MOSTLY_Q4_1},
{"q4_2", GGML_FTYPE_MOSTLY_Q4_2},
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
@ -45,6 +46,7 @@ bool ggml_common_quantize_0(
switch (ftype) {
case GGML_FTYPE_MOSTLY_Q4_0: qtype = GGML_TYPE_Q4_0; break;
case GGML_FTYPE_MOSTLY_Q4_1: qtype = GGML_TYPE_Q4_1; break;
case GGML_FTYPE_MOSTLY_Q4_2: qtype = GGML_TYPE_Q4_2; break;
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
@ -52,11 +54,6 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_Q2_K:
case GGML_FTYPE_MOSTLY_Q3_K:
case GGML_FTYPE_MOSTLY_Q4_K:
case GGML_FTYPE_MOSTLY_Q5_K:
case GGML_FTYPE_MOSTLY_Q6_K:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -174,6 +171,10 @@ bool ggml_common_quantize_0(
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_2:
{
cur_size = ggml_quantize_q4_2(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
@ -192,12 +193,6 @@ bool ggml_common_quantize_0(
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));

View File

@ -6,21 +6,13 @@
#include "dr_wav.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <regex>
#include <locale>
#include <codecvt>
#include <sstream>
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
@ -34,45 +26,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
params.top_k = std::max(1, std::stoi(argv[++i]));
params.top_k = std::stoi(argv[++i]);
} else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]);
} else if (arg == "--temp") {
params.temp = std::stof(argv[++i]);
} else if (arg == "--repeat-last-n") {
params.repeat_last_n = std::stof(argv[++i]);
} else if (arg == "--repeat-penalty") {
params.repeat_penalty = std::stof(argv[++i]);
} else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "-ip" || arg == "--interactive-port") {
params.interactive = true;
params.interactive_port = std::stoi(argv[++i]);
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
} else if (arg == "-f" || arg == "--file") {
if (++i > argc) {
fprintf(stderr, "Invalid file param");
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
break;
}
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
} else if (arg == "-tt" || arg == "--token_test") {
params.token_test = argv[++i];
}
else {
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
exit(0);
@ -91,16 +57,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " load prompt from a file\n");
fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
fprintf(stderr, " test tokenization\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
@ -141,10 +101,6 @@ std::string replace(const std::string & s, const std::string & from, const std::
return result;
}
void gpt_vocab::add_special_token(const std::string & token) {
special_tokens.push_back(token);
}
std::map<std::string, int32_t> json_parse(const std::string & fname) {
std::map<std::string, int32_t> result;
@ -236,82 +192,54 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result;
}
std::string convert_to_utf8(const std::wstring & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(input);
}
std::wstring convert_to_wstring(const std::string & input) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.from_bytes(input);
}
void gpt_split_words(std::string str, std::vector<std::string>& words) {
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
std::string special_tokens_subpattern;
for (const auto & token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
}
std::regex re(pat);
std::smatch m;
std::regex re(special_tokens_subpattern);
std::smatch m;
// Split the text by special tokens.
while (std::regex_search(str, m, re)) {
// Split the substrings in-between special tokens into words.
gpt_split_words(m.prefix(), words);
// Add matched special tokens as words.
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
// Remaining text without special tokens will be handled below.
str = m.suffix();
}
gpt_split_words(str, words);
}
// find the longest token that forms each word in words:
// find the longest tokens that form the words:
std::vector<gpt_vocab::id> tokens;
for (const auto & word : words) {
for (int i = 0; i < (int) word.size(); ){
for (int j = word.size() - 1; j >= i; j--){
auto cand = word.substr(i, j-i+1);
auto it = vocab.token_to_id.find(cand);
if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j + 1;
i = j;
break;
}
else if (j == i){ // word.substr(i, 1) has no matching
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
i++;
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
@ -319,70 +247,6 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
return tokens;
}
std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
std::vector<gpt_vocab::id> output;
std::stringstream ss(input);
std::string token;
while (std::getline(ss, token, delimiter)) {
output.push_back(std::stoi(token));
}
return output;
}
std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
if (fpath_test.empty()){
fprintf(stderr, "%s : No test file found.\n", __func__);
return std::map<std::string, std::vector<gpt_vocab::id>>();
}
std::map<std::string, std::vector<gpt_vocab::id>> tests;
auto fin = std::ifstream(fpath_test, std::ios_base::in);
const char * delimeter = " => ";
const char del_tok = ',';
std::string line;
while (std::getline(fin, line)) {
size_t delimiterPos = line.find(delimeter);
if (delimiterPos != std::string::npos) {
std::string text = line.substr(0, delimiterPos);
std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));
tests[text] = parse_tokens_from_string(s_tokens, del_tok);
}
}
return tests;
}
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);
size_t n_fails = 0;
for (const auto & test : tests) {
std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);
if (tokens != test.second){
n_fails++;
// print out failure cases
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
fprintf(stderr, "%s : tokens in hf: ", __func__);
for (const auto & t : test.second) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
fprintf(stderr, "%s : tokens in ggml: ", __func__);
for (const auto & t : tokens) {
fprintf(stderr, "%s(%d), ", vocab.id_to_token[t].c_str(), t);
}
fprintf(stderr, "\n");
}
}
fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
@ -482,122 +346,6 @@ gpt_vocab::id gpt_sample_top_k_top_p(
return logits_id[idx].second;
}
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = vocab.id_to_token.size();
const auto * plogits = logits;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
gpt_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}
std::vector<std::pair<double, gpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
// printf("\n");
// for (int i = 0; i < (int) probs.size(); i++) {
// for (int i = 0; i < 10; i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
// }
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin

View File

@ -15,24 +15,19 @@
//
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t seed = -1; // RNG seed
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_predict = 200; // new tokens to predict
int32_t n_batch = 8; // batch size for prompt processing
// sampling parameters
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
int32_t repeat_last_n = 64;
float repeat_penalty = 1.00f;
int32_t top_k = 40;
float top_p = 0.9f;
float temp = 0.9f;
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt = "";
std::string token_test = "";
int32_t n_batch = 8; // batch size for prompt processing
bool interactive = false;
int32_t interactive_port = -1;
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@ -58,20 +53,11 @@ struct gpt_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string & token);
};
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
std::string convert_to_utf8(const std::wstring & input);
std::wstring convert_to_wstring(const std::string & input);
void gpt_split_words(std::string str, std::vector<std::string>& words);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
@ -84,15 +70,6 @@ void gpt_split_words(std::string str, std::vector<std::string>& words);
//
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
// test outputs of gpt_tokenize
//
// - compare with tokens generated by the huggingface tokenizer
// - test cases are chosen based on the model's main language (under 'prompt' directory)
// - if all sentences are tokenized identically, print 'All tests passed.'
// - otherwise, print sentence, huggingface tokens, ggml tokens
//
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test);
// load the tokens from encoder.json
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
@ -112,18 +89,6 @@ gpt_vocab::id gpt_sample_top_k_top_p(
double temp,
std::mt19937 & rng);
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
const gpt_vocab & vocab,
const float * logits,
const int32_t * last_n_tokens_data,
size_t last_n_tokens_data_size,
int top_k,
double top_p,
double temp,
int repeat_last_n,
float repeat_penalty,
std::mt19937 & rng);
//
// Audio utils
//

View File

@ -1,9 +0,0 @@
if (WHISPER_SDL2)
# stream
set(TARGET lsp)
add_executable(${TARGET} lsp.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -1,104 +0,0 @@
# Language Server
This example consists of a simple language server to expose both unguided
and guided (command) transcriptions by sending json messages over stdout/stdin
as well as a rather robust vim plugin that makes use of the language server.
## Vim plugin quick start
Compile the language server with
```bash
make lsp
```
Install the plugin itself by copying or symlinking whisper.vim into ~/.vim/autoload/
In your vimrc, set the path of your whisper.cpp directory and optionally add some keybinds.
```vim
let g:whisper_dir = "~/whisper.cpp"
" Start listening for commands when Ctrl - g is pressed in normal mode
nnoremap <C-G> call whisper#requestCommands()<CR>
" Start unguided transcription when Ctrl - g is pressed in insert mode
inoremap <C-G> <Cmd>call whisper#doTranscription()<CR>
```
## Vim plugin usage
The vim plugin was designed to closely follow the mnemonics of vim
`s:spoken_dict` is used to translate keys to their spoken form.
Keys corresponding to a string use that spoken value normally and when a motion is expected, but use the key itself when a character is expected.
Keys corresponding to a dict, like `i`, can have manual difinitions given to each possible commandset.
0 is normal (insert), 1 is motion (inside), 2 is it's usage as a single key ([till] i), and 3 is it's usage in an area selection (s -> [around] sentence)
Some punctuation items, like `-` are explicitly given pronunciations to prevent them from being picked as punctuation instead of an actual command word.
Not all commands will tokenize to a single token and this can interfere with interpretation. "yank" as an example, takes multiple tokens and correspondingly, will give more accurate detection when only the first "ya" is used. While it could be changed to something else that is a single token (copy), value was placed on maintaining vim mnemonics.
Commands that would normally move the editor into insert mode (insert, append, open, change) will begin unguided transcription.
Unguided transcription will end when a speech segment ends in exit.
Presence of punctuation can be designated by whether or not you add a pause between the previous speech segment and exit.
Exiting only occurs if exit is the last word, so "Take the first exit on your right" would not cause transcription to end.
After a command is evaluated, the plugin will continue listening for the next command.
While in command mode, "Exit" will end listening.
A best effort approach is taken to keep track of audio that is recorded while a previous chunk is still processing and immediately interpret it afterwards, but the current voice detection still needs a fairly sizable gap to determine when a command has been spoken.
Log information is sent to a special `whisper_log` buffer and can be accessed with
```vim
:e whisper_log
```
## Vim plugin configuration
`g:whisper_dir`
A full path to the whisper.cpp repo. It can be expanded in the definition like so:
```vim
let g:whisper_dir = expand("~/whisper.cpp/")
```
(The WHISPER_CPP_HOME environment variable is also checked for users of the existing whisper.nvim script)
`g:whisper_lsp_path`
Can be used to manually set the path to the language server.
If not defined, it will be inferred from the above whisper_dir
`g:whisper_model_path`
A full path to the model to load. If not defined, it will default to ggml-base.en.bin
`g:whisper_user_commands`
A dictionary of spoken commands that correspond to either strings or funcrefs.
This can be used to create connections with other user plugins, for example
```vim
let g:whisper_user_commands = {"gen": "llama#doLlamaGen"}
```
will trigger the llama.cpp plugin to begin generation when "gen" is spoken
## Language server methods
`registerCommandset`
`params` is a list of strings that should be checked for with this commandset. The server prepends a space to these strings before tokenizing.
Responds with
`result.index` an integer index for the commandset registered, which should be included when initiating a guided transcription to select this commandset.
Will return an error if any of the commands in the commandset have duplicate tokenizations
`guided`
`params.commandset_index` An index returned by a corresponding commandset registration. If not set, the most recently registered commandset is used.
`params.timestamp` A positive unsigned integer which designates a point in time which audio should begin processing from. If left blank, the start point of audio processing will be the moment the message is recieved. This should be left blank unless you have a timestamp from a previous response.
Responds with
`result.command_index` The numerical index (starting from 0) of the detected command in the selected commandset
`result.command_text` A string containing the command as provided in the commandset
`result.timestamp` A positive unsigned integer that designates the point in time which audio stopped being processed at. Pass this timestamp back in a subsequent message to mask the latency of transcription.
`unguided`
`params.no_context` Sets the corresponding whisper `no_context` param. Defaults to true. Might provide more accurate results for consecutive unguided transcriptions if those after the first are set to false.
`params.prompt` If provided, sets the initial prompt used during transcription.
`params.timestamp` A positive unsigned integer which designates a point in time which audio should begin processing from. If left blank, the start point of audio processing will be the moment the message is recieved. This should be left blank unless you have a timestamp from a previous response.
Responds with
`result.transcription` A string containing the transcribed text. N.B. This will almost always start with a space due to how text is tokenized.
`result.timestamp` A positive unsigned integer that designates the point in time which audio stopped being processed at. Pass this timestamp back in a subsequent message to mask the latency of transcription.

File diff suppressed because it is too large Load Diff

View File

@ -1,458 +0,0 @@
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "json.hpp"
#include <iostream>
#include <cassert>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <deque>
#include <set>
using json = nlohmann::json;
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 8000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
};
struct command {
std::vector<whisper_token> tokens;
std::string plaintext;
};
struct commandset {
std::vector<struct command> commands;
std::vector<whisper_token> prompt_tokens;
// TODO: Store longest command?
// Multi-token commands should have probabilities of subsequent logits
// given that the prior logit is correct.
// In this case, all commands must be iterated.
// This however, is likely highly involved as different tokens
// almost certainly have different spoken lengths
// It would also have performance implications equivalent to a beam search
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, "\n");
}
uint64_t wait_for_vad(audio_async & audio, json jparams, const whisper_params & params, uint64_t maxlength_ms, std::vector<float> & pcmf32) {
using namespace std::chrono;
uint64_t time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
uint64_t start_time = time_now;
if (jparams.contains("timestamp")) {
start_time = jparams.at("timestamp");
}
if(time_now - start_time < 500) {
//wait for a backlog of audio
std::this_thread::sleep_for(milliseconds(500 - (time_now - start_time)));
time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
} else if (time_now - start_time > 1000) {
audio.get(time_now-start_time, pcmf32);
size_t max_offset = pcmf32.size() - WHISPER_SAMPLE_RATE;
for(size_t offset=0;offset < max_offset;offset+=WHISPER_SAMPLE_RATE/10) {
std::vector<float> audio_chunk(&pcmf32[offset], &pcmf32[offset+WHISPER_SAMPLE_RATE]);
if(::vad_simple(audio_chunk, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
pcmf32.resize(offset+WHISPER_SAMPLE_RATE);
if (offset*1000/WHISPER_SAMPLE_RATE+1000 > maxlength_ms) {
//remove samples from the beginning
pcmf32.erase(pcmf32.begin(),pcmf32.end()-(maxlength_ms*WHISPER_SAMPLE_RATE/1000));
fprintf(stderr, "Shortened samples");
}
return start_time + offset*1000/WHISPER_SAMPLE_RATE+1000;
}
}
}
size_t window_duration = std::max((uint64_t)1000, time_now-start_time);
audio.get(window_duration, pcmf32);
while (!::vad_simple(pcmf32, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
std::this_thread::sleep_for(milliseconds(100));
time_now = time_point_cast<milliseconds>(system_clock::now()).time_since_epoch().count();
window_duration = std::max((uint64_t)1000,time_now-start_time);
audio.get(window_duration, pcmf32);
}
if (time_now - start_time > maxlength_ms) {
audio.get(maxlength_ms, pcmf32);
} else {
audio.get(time_now - start_time, pcmf32);
}
return time_now;
}
json unguided_transcription(struct whisper_context * ctx, audio_async &audio, json jparams, const whisper_params &params) {
std::vector<whisper_token> prompt_tokens;
std::vector<float> pcmf32;
uint64_t unprocessed_audio_timestamp = wait_for_vad(audio, jparams, params, 10000U, pcmf32);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
if (jparams.contains("prompt")) {
// unlikely to see much use. Under normal circumstances, no_context would be set to false
std::string prompt = jparams.at("prompt");
prompt_tokens.resize(1024);
int n = whisper_tokenize(ctx, prompt.c_str(), prompt_tokens.data(), 1024);
prompt_tokens.resize(n);
wparams.prompt_tokens = prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.size();
}
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = false;
wparams.translate = params.translate;
wparams.no_context = jparams.value("no_context", true);
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
throw json{
{"code", -32803},
{"message", "ERROR: whisper_full() failed"}
};
}
std::string result = whisper_full_get_segment_text(ctx,0);
return json {
{"transcription", result},
{"timestamp", unprocessed_audio_timestamp}
};
}
// command-list mode
// guide the transcription to match the most likely command from a provided list
json guided_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params, json jparams, std::vector<struct commandset> commandset_list) {
struct commandset cs = commandset_list[jparams.value("commandset_index", commandset_list.size()-1)];
std::vector<float> pcmf32;
uint64_t unprocessed_audio_timestamp = wait_for_vad(audio, jparams, params, 2000U, pcmf32);
fprintf(stderr, "%s: Speech detected! Processing ...\n", __func__);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = false;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// TODO: Do some time testing. Does an overly long prompt slow down processing?
// Set up command sets/precompute prompts
wparams.prompt_tokens = cs.prompt_tokens.data();
wparams.prompt_n_tokens = cs.prompt_tokens.size();
// TODO: properly expose as option
wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
throw json{
{"code", -32803},
{"message", "ERROR: whisper_full() failed"}//TODO: format string (sprintf?)
};
}
// estimate command probability
// NOTE: not optimal
{
const auto * logits = whisper_get_logits(ctx);
std::vector<float> probs(whisper_n_vocab(ctx), 0.0f);
// compute probs from logits via softmax
{
float max = -1e9;
for (int i = 0; i < (int) probs.size(); ++i) {
max = std::max(max, logits[i]);
}
float sum = 0.0f;
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] = expf(logits[i] - max);
sum += probs[i];
}
for (int i = 0; i < (int) probs.size(); ++i) {
probs[i] /= sum;
}
}
std::vector<std::pair<float, int>> probs_id;
// In my testing, the most verbose token is always the desired.
// TODO: Trim commandset struct once efficacy has been verified
for (int i = 0; i < (int) cs.commands.size(); ++i) {
probs_id.emplace_back(probs[cs.commands[i].tokens[0]], i);
}
// sort descending
{
using pair_type = decltype(probs_id)::value_type;
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
int id = probs_id[0].second;
return json{
{"command_index", id},
{"command_text", cs.commands[id].plaintext},
{"timestamp", unprocessed_audio_timestamp},
};
}
}
json register_commandset(struct whisper_context * ctx, json jparams, std::vector<struct commandset> &commandset_list) {
// TODO: check for token collision
struct commandset cs;
std::string k_prompt = " select one from the available words: ";
std::set<whisper_token> token_set;
whisper_token tokens[32];
for (std::string s : jparams) {
std::vector<whisper_token> token_vec;
// The existing command implementation uses a nested for loop to tokenize single characters
// I fail to see the purpose of this when ' a' has a wholly different pronunciation than the start of ' apple'
const int n = whisper_tokenize(ctx, (" " + s).c_str(), tokens, 32);
if (n < 0) {
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, s.c_str());
return 3;
}
token_vec.push_back(tokens[0]);
if (!token_set.insert(tokens[0]).second) {
fprintf(stderr, "%s: warning: %s is a duplicate of an existing token\n", __func__, s.c_str());
throw json{
{"code",-31000},
{"message", "Duplicate token in token set: " + s}
};
}
if (n > 1) {// empty string if n=0? Should never occur
fprintf(stderr, "%s: error: command is more than a single token: %s\n", __func__, s.c_str());
}
struct command command = {token_vec, s};
cs.commands.push_back(command);
k_prompt += s;
}
k_prompt = k_prompt.substr(0,k_prompt.length()-2) + ". Selected word:";
cs.prompt_tokens.resize(1024);
int n = whisper_tokenize(ctx, k_prompt.c_str(), cs.prompt_tokens.data(), 1024);
cs.prompt_tokens.resize(n);
// prepare response
int index = commandset_list.size();
commandset_list.push_back(cs);
return json{{"index",index}};
}
json seek(struct whisper_context * ctx, audio_async &audio, json params) {
// whisper_state has the pertinent offsets, but there also seem to be a large
// number of scratch buffers that would prevent rewinding context in a manner similar to llama
// I'll give this a another pass once everything else is implemented,
// but for now, it's unsupported
throw json{
{"code", -32601},
{"message", "Seeking is not yet supported."}
};
}
json parse_job(const json &body, struct whisper_context * ctx, audio_async &audio, const whisper_params &params, std::vector<struct commandset> &commandset_list) {
// See: https://www.jsonrpc.org/specification
json id = body.at("id");
try {
std::string version = body.at("jsonrpc");
if (version != "2.0") {
// unsupported version
throw json{
{"code", -3260},
{"message", "invalid jsonrpc version"}
};
}
std::string method = body.at("method");
json jparams = json{{"dummy", "dummy"}};
if (body.contains("params"))
jparams = body.at("params");
json res;
// TODO: be consistent about argument order
fprintf(stderr, "Dispatching a job\n");
if (method == "unguided") { res = unguided_transcription(ctx, audio, jparams, params); }
else if (method == "guided") { res = guided_transcription(ctx, audio, params, jparams, commandset_list); }
else if (method == "seek") { res = seek(ctx, audio, jparams); }
else if (method == "registerCommandset") { res = register_commandset(ctx, jparams, commandset_list); }
else if (method == "echo") { res = jparams; }
return json{
{"jsonrpc", "2.0"},
{"result", res},
{"id", id}
};
} catch(json ex) {
return json {
{"jsonrpc", "2.0"},
{"error", ex},
{"id", id}
};
}
}
void process_loop(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
std::deque<json> jobqueue;
std::vector<struct commandset> commandset_list;
while (true) {
// For eventual cancellation support, shouldn't block if job exists
if (std::cin.rdbuf()->in_avail() > 22 || jobqueue.size() == 0) {
int content_length;
if (scanf("Content-Length: %d", &content_length) != 1) {
fprintf(stderr, "Could not read input: %d", std::cin.peek());
return;
}
// scanf leaves the new lines intact
std::cin.ignore(2);
if (std::cin.peek() != 13) {
// Content-Type. jsonrpc necessitates utf8.
std::cin.ignore(200,10);
}
std::cin.ignore(2);
// A message is being sent and blocking is acceptable
std::string content(content_length,'\0');
std::cin.read(&content[0], content_length);
json job = json::parse(content);
// TODO: Some messages(cancellation) should skip queue here
if (job.is_array()) {
// response must also be batched. Will implement later
// for (subjob : job.begin())
// TODO: At the very least respond with an unsupported error.
} else {
jobqueue.push_back(job);
}
}
assert(jobqueue.size() > 0);
json job = jobqueue.front();
json resp = parse_job(job, ctx, audio, params, commandset_list);
if (resp != "unfinished") {
jobqueue.pop_front();
// send response
std::string data = resp.dump(-1, ' ', false, json::error_handler_t::replace);
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", data.length()+1, data.c_str());
std::cout.flush();
}
}
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
// TODO: Investigate why this is required. An extra second of startup latency is not great
// wait for 1 second to avoid any buffered noise
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
audio.clear();
// TODO: consider some sort of indicator to designate loading has finished?
// Potentially better for the client to just start with a non-blocking message (register commands)
process_loop(ctx, audio, params);
audio.pause();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -1,362 +0,0 @@
if !exists("g:whisper_dir")
let g:whisper_dir = expand($WHISPER_CPP_HOME)
if g:whisper_dir == ""
echoerr "Please provide a path to the whisper.cpp repo in either the $WHISPER_CPP_HOME environment variable, or g:whisper_dir"
endif
endif
if !exists("g:whisper_lsp_path")
let g:whisper_lsp_path = g:whisper_dir .. "lsp"
if !filereadable(g:whisper_lsp_path)
echoerr "Was not able to locate a lsp executable at: " .. g:whisper_lsp_path
throw "Executable not found"
endif
endif
if !exists("g:whisper_model_path")
" TODO: allow custom paths relative to the repo dir
let g:whisper_model_path = g:whisper_dir .. "models/ggml-base.en.bin"
if !filereadable(g:whisper_model_path)
echoerr "Could not find model at: " .. g:whisper_model_path
throw "Model not found"
endif
endif
let s:output_buffer = bufnr("whisper_log", v:true)
call setbufvar(s:output_buffer,"&buftype","nofile")
let s:lsp_command = [g:whisper_lsp_path,"-m",g:whisper_model_path]
" For faster execution. TODO: server load multiple models/run multiple servers?
" let s:lsp_command = [g:whisper_lsp_path, "-m", g:whisper_dir .. "models/ggml-tiny.en.bin", "-ac", "128"]
" requestCommands([params_dict])
func whisper#requestCommands(...)
let l:req = {"method": "guided", "params": {"commandset_index": 0}}
if a:0 > 0
call extend(l:req.params, a:1)
endif
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:commandCallback", [l:req.params, 0])})
endfunction
" doTranscription([params_dict])
func whisper#doTranscription(...)
let l:req = {"method": "unguided", "params": {}}
if a:0 > 0
call extend(l:req.params, a:1)
endif
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:transcriptionCallback", [function("s:insertText"),function("s:endTranscription")])})
endfunction
" For testing
func whisper#uppertest(cha)
echo tr(a:cha, s:c_lowerkeys, s:c_upperkeys)
endfunction
" (upper, exit, count, motion, command, insert/append, save run) "base"
" (upper, exit, count, motion, command, inside/around) "motion/visual"
" (upper, exit, count, motion, line, inside/around) "command already entered"
" (upper, exit, key, ) "from/till"
" upper and lower keys is used to translate between cases with tr
" Must be sunchronized
let s:c_lowerkeys = "1234567890-=qwertyuiop[]\\asdfghjkl;'zxcvbnm,./\""
let s:c_upperkeys = "!@#$%^&*()_+QWERTYUIOP{}|ASDFGHJKL:\"ZXCVBNM<>?'"
let s:c_count = split("1234567890\"",'\zs')
let s:c_command = split("ryuogpdxcv.iam", '\zs')
let s:c_motion = split("wetf'hjklnb$^)",'\zs')
" object words: Word, Sentence, Paragraph, [, (, <, Tag, {. ", '
let s:c_area = split("wsp])>t}\"'",'\zs')
"Special commands.
let s:c_special_always = ["exit", "upper"]
let s:c_special_normal = ["save", "run", "space"]
" If not in dict, key is spoken word,
" If key resolves to string, value is used for normal/motion, but key for chars
" If key resolves to dict, {0: "normal",1: "motion",2:"single char",3: "area"}
" Missing entries fall back as follows {0: "required", 1: 0, 2: "key", 3: 0}
let s:spoken_dict = {"w": "word", "e": "end", "r": "replace", "t": {0: "till", 3: "tag"}, "y": "yank", "u": "undo", "i": {0: "insert", 1: "inside"}, "o": "open", "p": {0: "paste", 3: "paragraph"}, "a": {0: "append", 1: "around"}, "s": {0: "substitute", 3: "sentence"}, "d": "delete", "f": "from", "g": "go", "h": "left", "j": "down", "k": "up", "l": "right", "c": "change", "v": "visual", "b": "back", "n": "next", "m": "mark", ".": {0: "repeat", 2: "period"}, "]": {0: "bracket", 2: "bracket"}, "'": {0: "jump", 2: "apostrophe", 3: "apostrophe"}, '"': {0: 'register', 2: "quotation", 3: "quotation"}, "-": {0: "minus", 2: "minus"}, "$": {0: "dollar", 2: "dollar"}, "^": {0: "carrot", 2: "carrot"}, ")": {0: "sentence", 2: "parenthesis", 3: "parenthesis"}, "}": {0: "paragraph", 2: "brace", 3: "brace"}, ">": {0: "indent", 2: "angle", 3: "angle"}}
" Give this another pass. This seems overly hacky even if it's functional
let s:sub_tran_msg = ""
func s:subTranProg(msg)
if s:sub_tran_msg != ""
let s:sub_tran_msg = s:sub_tran_msg .. a:msg
if mode() !=? 'v'
exe "normal" "u" .. s:sub_tran_msg
endif
else
if s:command_backlog == ""
" this should not occur
call s:logCallback(0, "Warning: Encountered sub transcription without prior command")
let s:command_backlog = "a"
endif
if a:msg[0] == ' '
let s:sub_tran_msg = s:command_backlog .. a:msg[1:-1]
else
let s:sub_tran_msg = s:command_backlog .. a:msg
endif
if mode() !=? 'v'
exe "normal" s:sub_tran_msg
endif
endif
call appendbufline(s:output_buffer, "$", s:sub_tran_msg .. ":" .. string(a:msg ))
endfunction
func s:subTranFinish(params, timestamp)
let s:repeat_command = s:sub_tran_msg
" Visual selection is lot if used with streaming, so streaming of partial
" transcriptions is disabled in visual mode
if mode() ==? 'v'
exe "normal" s:sub_tran_msg
endif
let s:sub_tran_msg = ""
let s:command_backlog = ""
exe "normal a\<C-G>u"
let l:params = a:params
let l:params.timestamp = a:timestamp
if exists("l:params.commandset_index")
unlet l:params.commandset_index
endif
call whisper#requestCommands(a:params)
endfunction
func s:logCallback(channel, msg)
call appendbufline(s:output_buffer,"$",a:msg)
endfunction
func s:transcriptionCallback(progressCallback, finishedCallback, channel, msg)
let l:tr = a:msg.result.transcription
let l:ex_ind = match(tolower(l:tr),"exit", len(l:tr)-6)
" The worst case I've observed so far is " Exit.", which is 6 characters
if l:ex_ind != -1
call a:progressCallback(strpart(l:tr,0,l:ex_ind-1))
call a:finishedCallback(a:msg.result.timestamp)
else
call a:progressCallback(l:tr)
let req = {"method": "unguided", "params": {"timestamp": a:msg.result.timestamp, "no_context": v:true}}
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [a:progressCallback, a:finishedCallback])})
endif
endfunc
func s:insertText(msg)
exe "normal a" .. a:msg
endfunction
func s:endTranscription(timestamp)
call appendbufline(s:output_buffer, "$", "Ending unguided transcription")
endfunction
" If a command does not include a whole actionable step, attempting to execute
" it discards the remainder of things. There is likely a simpler solution,
" but it can be made functional now by storing a backbuffer until actionable
let s:command_backlog = ""
let s:repeat_command = ""
let s:preceeding_upper = v:false
func s:commandCallback(params, commandset_index, channel, msg)
let l:command_index = a:msg.result.command_index
let l:do_execute = v:false
let l:next_mode = a:commandset_index
let l:command = s:commandset_list[a:commandset_index][l:command_index]
call s:logCallback(0, string(a:msg) .. " " .. a:commandset_index .. " " .. l:command)
if l:command_index == 0
"exit
"if s:command_backlog == ""
call s:logCallback(0,"Stopping command mode")
echo "No longer listening"
let s:command_backlog = ""
return
"else
" Legacy code to clear an existing buffer with exit.
" Was found to be rarely desired and is better introduced as a
" standalone command (clear?)
" call s:logCallback(0,"Clearing command_backlog" .. s:command_backlog)
" let s:command_backlog = ""
" let s:preceeding_upper = v:false
" endif
elseif l:command_index == 1
" upper
let s:preceeding_upper = !s:preceeding_upper
elseif l:command == "save"
" save and run can only happen in commandset 0,
exe "w"
elseif l:command == "run"
exe "make run"
elseif l:command == "space"
exe "normal i \<ESC>l"
elseif has_key(s:c_user, l:command)
let Userfunc = s:c_user[l:command]
if type(Userfunc) == v:t_string
let Userfunc = function(Userfunc)
endif
call Userfunc()
else
if s:preceeding_upper
" Upper should keep commandset
let s:preceeding_upper = v:false
let l:visual_command = tr(l:command, s:c_lowerkeys, s:c_upperkeys)
else
let l:visual_command = l:command
endif
echo s:command_backlog .. " - " .. l:visual_command
let s:command_backlog = s:command_backlog .. l:visual_command
if a:commandset_index == 2 || a:commandset_index == 3
" single key, either completes motion, replace, or register
" Should move to execute unless part of a register
" Change will be caught at execute
if s:command_backlog[-2:-2] !=# '"'
call s:logCallback(0,"not register")
let l:do_execute = v:true
end
let l:next_mode = 0
" commandset index only matters for a/i
elseif (l:command == "a" || l:command == "i") && a:commandset_index == 1
" inside/around. Is commandset 3
let l:next_mode = 3
elseif l:command ==# '"'
let l:next_mode = 2
elseif index(s:c_count, l:command) != -1
let l:next_mode = a:commandset_index
elseif index(s:c_motion, l:command) != -1
if l:command == 't' || l:command == 'f' || l:command == "'"
" prompt single key
let l:next_mode = 2
else
let l:do_execute = v:true
let l:next_mode = 0
endif
elseif index(s:c_command, l:command) != -1
if index(["y","g","d","c"], s:command_backlog[-1:-1]) != -1 && s:command_backlog[-1:-1] != s:command_backlog[-2:-2] && mode() !=? 'v'
" need motion or repeated command
" Potential for bad state here if disparaging command keys are
" entered (i.e. yd), but vim can handle checks for this at exe
" And checking for cases like y123d would complicate things
let l:next_mode = 1
elseif index(["i","a","c", "o", "s"], l:command) != -1 || s:command_backlog[-1:-1] ==# 'R'
"'Insert' mode, do general transcription
let l:req = {"method": "unguided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.no_context = v:true
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [function("s:subTranProg"), function("s:subTranFinish", [a:params])])})
return
elseif l:command == 'r' || l:command == 'm'
let l:next_mode = 2
elseif l:command == '.'
let l:next_mode = 0
let l:do_execute = v:true
let s:command_backlog = s:command_backlog[0:-2] .. s:repeat_command
else
if l:command ==? 'v'
let l:next_mode = 1
else
let l:next_mode = 0
endif
let l:do_execute = v:true
endif
else
throw "Invalid command state: " .. l:command .. " " .. a:commandset_index .. " " .. s:command_backlog
endif
endif
if l:do_execute
if mode() ==?'v' && l:next_mode == 0
let l:next_mode = 1
elseif match(s:command_backlog, 'c') != -1
let l:req = {"method": "unguided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.no_context = v:true
let resp = ch_sendexpr(g:lsp_job, req, {"callback": function("s:transcriptionCallback", [function("s:subTranProg"), function("s:subTranFinish", [a:params])])})
return
endif
exe "normal" s:command_backlog
if index(s:c_motion + ["u"],l:command) == -1
exe "normal a\<C-G>u"
let s:repeat_command = s:command_backlog
call s:logCallback(0, s:command_backlog)
endif
let s:command_backlog = ""
endif
let l:req = {"method": "guided", "params": a:params}
let l:req.params.timestamp = a:msg.result.timestamp
let l:req.params.commandset_index = l:next_mode
let resp = ch_sendexpr(g:lsp_job, l:req, {"callback": function("s:commandCallback",[a:params, l:next_mode])})
endfunction
func s:loadedCallback(channel, msg)
echo "Loading complete"
call s:logCallback(a:channel, a:msg)
endfunction
func s:registerCommandset(commandlist, is_final)
let req = {"method": "registerCommandset"}
let req.params = a:commandlist
call s:logCallback(0, join(a:commandlist))
call add(g:whisper_commandlist_spoken, a:commandlist)
if a:is_final
let resp = ch_sendexpr(g:lsp_job, req, {"callback": "s:loadedCallback"})
else
let resp = ch_sendexpr(g:lsp_job, req, {"callback": "s:logCallback"})
endif
endfunction
func s:registerAllCommands()
let l:normal = s:c_special_always + s:c_special_normal + s:c_count + s:c_command + s:c_motion + keys(s:c_user)
let l:visual = s:c_special_always + s:c_count + s:c_command + s:c_motion
" Currently the same as visual.
" let l:post_command = s:c_special_always + s:c_count + s:c_command + s:c_motion
let l:single_key = s:c_special_always + split(s:c_lowerkeys, '\zs')
let l:area = s:c_special_always + s:c_area
" Used only for compatibility with the testing script
let g:whisper_commandlist_spoken = []
let s:commandset_list = [l:normal, l:visual, l:single_key, l:area]
call s:registerCommandset(s:commandsetToSpoken(l:normal, 0), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:visual, 1), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:single_key, 2), v:false)
call s:registerCommandset(s:commandsetToSpoken(l:area, 3), v:true)
endfunction
func s:commandsetToSpoken(commandset, spoken_index)
let l:spoken_list = []
for l:command in a:commandset
if has_key(s:spoken_dict, l:command)
let l:spoken_value = s:spoken_dict[l:command]
if type(l:spoken_value) == v:t_dict
if has_key(l:spoken_value, a:spoken_index)
let l:spoken_value = l:spoken_value[a:spoken_index]
else
if a:spoken_index == 2
let l:spoken_value = l:command
else
let l:spoken_value = l:spoken_value[0]
endif
endif
else
if a:spoken_index == 2
let l:spoken_value = l:command
endif
endif
else
let l:spoken_value = l:command
endif
call add(l:spoken_list, l:spoken_value)
endfor
return l:spoken_list
endfunction
" TODO: Check lifetime. If the script is resourced, is the existing
" s:lsp_job dropped and therefore killed?
" This seems to not be the case and I've had to deal with zombie processes
" that survive exiting vim, even though said behavior conflicts with my
" understanding of the provided documentation
let s:lsp_opts = {"in_mode": "lsp", "out_mode": "lsp", "err_mode": "nl", "err_io": "buffer", "err_buf": s:output_buffer}
if !exists("g:lsp_job")
if exists("g:whisper_user_commands")
let s:c_user = g:whisper_user_commands
else
let s:c_user = {}
endif
let g:lsp_job = job_start(s:lsp_command, s:lsp_opts)
if job_status(g:lsp_job) == "fail"
echoerr "Failed to start whisper job"
endif
call s:registerAllCommands()
endif

View File

@ -10,10 +10,6 @@
#include <vector>
#include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
const std::vector<std::string> k_colors = {
@ -59,7 +55,6 @@ struct whisper_params {
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
@ -69,36 +64,28 @@ struct whisper_params {
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
bool log_score = false;
bool speed_up = false;
bool translate = false;
bool detect_language= false;
bool diarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool output_txt = false;
bool output_vtt = false;
bool output_srt = false;
bool output_wts = false;
bool output_csv = false;
bool output_jsn = false;
bool output_lrc = false;
bool print_special = false;
bool print_colors = false;
bool print_progress = false;
bool no_timestamps = false;
std::string language = "en";
std::string language = "en";
std::string prompt;
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
std::string model = "models/ggml-base.en.bin";
std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
@ -124,45 +111,41 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -192,11 +175,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
@ -210,14 +191,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, "\n");
}
@ -225,50 +204,8 @@ struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
int progress_prev;
};
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "0";
} else if (energy1 > 1.1*energy0) {
speaker = "1";
} else {
speaker = "?";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only) {
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
if (progress >= *progress_prev + progress_step) {
*progress_prev += progress_step;
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
}
}
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -298,7 +235,28 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
if (energy0 > 1.1*energy1) {
speaker = "(speaker 0)";
} else if (energy1 > 1.1*energy0) {
speaker = "(speaker 1)";
} else {
speaker = "(speaker ?)";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
}
if (params.print_colors) {
@ -323,12 +281,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
@ -338,7 +290,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
}
}
bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_txt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -350,22 +302,13 @@ bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << speaker << text << "\n";
fout << text << "\n";
}
return true;
}
bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_vtt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -381,23 +324,15 @@ bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
fout << speaker << text << "\n\n";
fout << text << "\n\n";
}
return true;
}
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -411,16 +346,10 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << i + 1 + params.offset_n << "\n";
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
fout << speaker << text << "\n\n";
fout << text << "\n\n";
}
return true;
@ -457,7 +386,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
return escaped;
}
bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_csv(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -467,13 +396,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
fout << "start,end,";
if (params.diarize && pcmf32s.size() == 2)
{
fout << "speaker,";
}
fout << "text\n";
fout << "start,end,text\n";
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
@ -481,37 +404,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
char * text_escaped = escape_double_quotes_and_backslashes(text);
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
fout << 10 * t0 << "," << 10 * t1 << ",";
if (params.diarize && pcmf32s.size() == 2)
{
fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
}
fout << "\"" << text_escaped << "\"\n";
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
}
return true;
}
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
const int n_segments = whisper_full_n_segments(ctx);
// fprintf(stderr,"segments: %d\n",n_segments);
for (int i = 0; i < n_segments; ++i) {
const int n_tokens = whisper_full_n_tokens(ctx, i);
// fprintf(stderr,"tokens: %d\n",n_tokens);
for (int j = 0; j < n_tokens; j++) {
auto token = whisper_full_get_token_text(ctx, i, j);
auto probability = whisper_full_get_token_p(ctx, i, j);
fout << token << '\t' << probability << std::endl;
// fprintf(stderr,"token: %s %f\n",token,probability);
}
}
return true;
}
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
std::ofstream fout(fname);
int indent = 0;
@ -525,13 +424,13 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_arr = [&](bool end) {
auto end_arr = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "]\n" : "},\n");
};
auto start_obj = [&](const char *name) {
auto start_obj = [&](const char *name = nullptr) {
doindent();
if (name) {
fout << "\"" << name << "\": {\n";
@ -541,7 +440,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
indent++;
};
auto end_obj = [&](bool end) {
auto end_obj = [&](bool end = false) {
indent--;
doindent();
fout << (end ? "}\n" : "},\n");
@ -552,24 +451,24 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
fout << "\"" << name << "\": ";
};
auto value_s = [&](const char *name, const char *val, bool end) {
auto value_s = [&](const char *name, const char *val, bool end = false) {
start_value(name);
char * val_escaped = escape_double_quotes_and_backslashes(val);
fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n");
free(val_escaped);
};
auto end_value = [&](bool end) {
auto end_value = [&](bool end = false) {
fout << (end ? "\n" : ",\n");
};
auto value_i = [&](const char *name, const int64_t val, bool end) {
auto value_i = [&](const char *name, const int64_t val, bool end = false) {
start_value(name);
fout << val;
end_value(end);
};
auto value_b = [&](const char *name, const bool val, bool end) {
auto value_b = [&](const char *name, const bool val, bool end = false) {
start_value(name);
fout << (val ? "true" : "false");
end_value(end);
@ -581,62 +480,53 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
}
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
start_obj(nullptr);
value_s("systeminfo", whisper_print_system_info(), false);
start_obj();
value_s("systeminfo", whisper_print_system_info());
start_obj("model");
value_s("type", whisper_model_type_readable(ctx), false);
value_b("multilingual", whisper_is_multilingual(ctx), false);
value_i("vocab", whisper_model_n_vocab(ctx), false);
value_s("type", whisper_model_type_readable(ctx));
value_b("multilingual", whisper_is_multilingual(ctx));
value_i("vocab", whisper_model_n_vocab(ctx));
start_obj("audio");
value_i("ctx", whisper_model_n_audio_ctx(ctx), false);
value_i("state", whisper_model_n_audio_state(ctx), false);
value_i("head", whisper_model_n_audio_head(ctx), false);
value_i("ctx", whisper_model_n_audio_ctx(ctx));
value_i("state", whisper_model_n_audio_state(ctx));
value_i("head", whisper_model_n_audio_head(ctx));
value_i("layer", whisper_model_n_audio_layer(ctx), true);
end_obj(false);
end_obj();
start_obj("text");
value_i("ctx", whisper_model_n_text_ctx(ctx), false);
value_i("state", whisper_model_n_text_state(ctx), false);
value_i("head", whisper_model_n_text_head(ctx), false);
value_i("ctx", whisper_model_n_text_ctx(ctx));
value_i("state", whisper_model_n_text_state(ctx));
value_i("head", whisper_model_n_text_head(ctx));
value_i("layer", whisper_model_n_text_layer(ctx), true);
end_obj(false);
value_i("mels", whisper_model_n_mels(ctx), false);
end_obj();
value_i("mels", whisper_model_n_mels(ctx));
value_i("ftype", whisper_model_ftype(ctx), true);
end_obj(false);
end_obj();
start_obj("params");
value_s("model", params.model.c_str(), false);
value_s("language", params.language.c_str(), false);
value_s("model", params.model.c_str());
value_s("language", params.language.c_str());
value_b("translate", params.translate, true);
end_obj(false);
end_obj();
start_obj("result");
value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true);
end_obj(false);
end_obj();
start_arr("transcription");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
start_obj(nullptr);
start_obj();
start_obj("timestamps");
value_s("from", to_timestamp(t0, true).c_str(), false);
value_s("from", to_timestamp(t0, true).c_str());
value_s("to", to_timestamp(t1, true).c_str(), true);
end_obj(false);
end_obj();
start_obj("offsets");
value_i("from", t0 * 10, false);
value_i("from", t0 * 10);
value_i("to", t1 * 10, true);
end_obj(false);
value_s("text", text, !params.diarize && !params.tinydiarize);
if (params.diarize && pcmf32s.size() == 2) {
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
}
if (params.tinydiarize) {
value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
}
end_obj();
value_s("text", text, true);
end_obj(i == (n_segments - 1));
}
@ -648,7 +538,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
// karaoke video generation
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
// TODO: font parameter adjustments
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
std::ofstream fout(fname);
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
@ -685,11 +575,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
bool is_first = true;
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
for (int j = 0; j < n; ++j) {
const auto & token = tokens[j];
@ -698,19 +583,13 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
continue;
}
std::string txt_bg = "";
std::string txt_fg = ""; // highlight token
std::string txt_ul = ""; // underline
std::string txt_bg;
std::string txt_fg; // highlight token
std::string txt_ul; // underline
if (params.diarize && pcmf32s.size() == 2) {
txt_bg = speaker;
txt_fg = speaker;
txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
}
txt_bg.append("> ");
txt_fg.append("> ");
txt_ul.append("\\ \\ ");
txt_bg = "> ";
txt_fg = "> ";
txt_ul = "\\ \\ ";
{
for (int k = 0; k < n; ++k) {
@ -773,7 +652,8 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
return true;
}
bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
bool output_lrc(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname);
if (!fout.is_open()) {
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
@ -798,16 +678,8 @@ bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_
char buf[16];
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
std::string timestamp_lrc = std::string(buf);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
fout << '[' << timestamp_lrc << ']' << text << "\n";
}
return true;
@ -817,7 +689,6 @@ int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
return 1;
}
@ -833,12 +704,6 @@ int main(int argc, char ** argv) {
exit(0);
}
if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@ -848,9 +713,6 @@ int main(int argc, char ** argv) {
return 3;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
@ -883,12 +745,11 @@ int main(int argc, char ** argv) {
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
@ -918,9 +779,6 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.initial_prompt = params.prompt.c_str();
@ -931,7 +789,7 @@ int main(int argc, char ** argv) {
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
@ -939,11 +797,6 @@ int main(int argc, char ** argv) {
wparams.new_segment_callback_user_data = &user_data;
}
if (wparams.print_progress) {
wparams.progress_callback = whisper_print_progress_callback;
wparams.progress_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
@ -970,49 +823,43 @@ int main(int argc, char ** argv) {
// output to text file
if (params.output_txt) {
const auto fname_txt = fname_out + ".txt";
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
output_txt(ctx, fname_txt.c_str());
}
// output to VTT file
if (params.output_vtt) {
const auto fname_vtt = fname_out + ".vtt";
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
output_vtt(ctx, fname_vtt.c_str());
}
// output to SRT file
if (params.output_srt) {
const auto fname_srt = fname_out + ".srt";
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
output_srt(ctx, fname_srt.c_str(), params);
}
// output to WTS file
if (params.output_wts) {
const auto fname_wts = fname_out + ".wts";
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
}
// output to CSV file
if (params.output_csv) {
const auto fname_csv = fname_out + ".csv";
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
output_csv(ctx, fname_csv.c_str());
}
// output to JSON file
if (params.output_jsn) {
const auto fname_jsn = fname_out + ".json";
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
output_json(ctx, fname_jsn.c_str(), params);
}
// output to LRC file
if (params.output_lrc) {
const auto fname_lrc = fname_out + ".lrc";
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
}
// output to score file
if (params.log_score) {
const auto fname_score = fname_out + ".score.txt";
output_score(ctx, fname_score.c_str(), params, pcmf32s);
output_lrc(ctx, fname_lrc.c_str());
}
}
}

View File

@ -25,7 +25,7 @@ struct whisper_hparams {
int32_t n_text_head = 6;
int32_t n_text_layer = 4;
int32_t n_mels = 80;
int32_t ftype = 1;
int32_t f16 = 1;
};
struct whisper_filters {
@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
{
uint32_t magic;
finp.read((char *) &magic, sizeof(magic));
if (magic != GGML_FILE_MAGIC) {
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
return false;
}
@ -79,10 +79,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
finp.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
finp.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
finp.read((char *) &hparams.f16, sizeof(hparams.f16));
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
@ -94,22 +91,19 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype);
fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src);
fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst);
fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fout.write((const char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((const char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((const char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((const char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((const char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((const char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((const char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((const char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((const char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((const char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((const char *) &ftype_dst, sizeof(hparams.ftype));
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fout.write((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fout.write((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fout.write((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fout.write((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fout.write((char *) &ftype, sizeof(hparams.f16));
}
// load mel filters
@ -138,17 +132,15 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
// return false;
//}
char word[129];
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
finp.read ((char *) &len, sizeof(len));
fout.write((char *) &len, sizeof(len));
word[len] = '\0';
finp.read ((char *) word, len);
fout.write((char *) word, len);
word.resize(len);
finp.read ((char *) word.data(), len);
fout.write((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;

View File

@ -47,7 +47,6 @@ struct whisper_params {
bool print_special = false;
bool no_context = true;
bool no_timestamps = false;
bool tinydiarize = false;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
@ -81,8 +80,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
@ -116,7 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, "\n");
}
@ -303,8 +299,6 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
// disable temperature fallback
//wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
@ -350,19 +344,10 @@ int main(int argc, char ** argv) {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string output = "[" + to_timestamp(t0) + " --> " + to_timestamp(t1) + "] " + text;
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
output += " [SPEAKER_TURN]";
}
output += "\n";
printf("%s", output.c_str());
fflush(stdout);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
if (params.fname_out.length() > 0) {
fout << output;
fout << "[" << to_timestamp(t0) << " --> " << to_timestamp(t1) << "] " << text << std::endl;
}
}
}

View File

@ -42,8 +42,8 @@ Example usage:
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or Windows SpeechSynthesizer, but you can use whatever you wish.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use MacOS's `say`, but you can use whatever you wish.
## Discussion

View File

@ -1,20 +1,23 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk-llama/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import generate, play, save
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
save(audio, "audio.mp3")
audio.save("audio")

View File

@ -14,7 +14,6 @@
#include <string>
#include <vector>
#include <stdexcept>
#ifdef __has_include
#if __has_include(<unistd.h>)
@ -75,7 +74,7 @@ struct llama_file {
llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
throw format("failed to open %s: %s", fname, std::strerror(errno));
}
seek(0, SEEK_END);
size = tell();
@ -101,17 +100,17 @@ struct llama_file {
LLAMA_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t len) const {
if (len == 0) {
void read_raw(void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, len, 1, fp);
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
throw std::runtime_error(format("read error: %s", strerror(errno)));
throw format("read error: %s", strerror(errno));
}
if (ret != 1) {
throw std::runtime_error(std::string("unexpectedly reached end of file"));
throw std::string("unexpectedly reached end of file");
}
}
@ -127,14 +126,14 @@ struct llama_file {
return std::string(chars.data(), len);
}
void write_raw(const void * ptr, size_t len) const {
if (len == 0) {
void write_raw(const void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
size_t ret = std::fwrite(ptr, len, 1, fp);
size_t ret = std::fwrite(ptr, size, 1, fp);
if (ret != 1) {
throw std::runtime_error(format("write error: %s", strerror(errno)));
throw format("write error: %s", strerror(errno));
}
}
@ -172,7 +171,7 @@ struct llama_mmap {
#ifdef _POSIX_MAPPED_FILES
static constexpr bool SUPPORTED = true;
llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
llama_mmap(struct llama_file * file, bool prefetch = true) {
size = file->size;
int fd = fileno(file->fp);
int flags = MAP_SHARED;
@ -181,13 +180,13 @@ struct llama_mmap {
#endif
addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
if (addr == MAP_FAILED) {
throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
throw format("mmap failed: %s", strerror(errno));
}
if (prefetch > 0) {
if (prefetch) {
// Advise the kernel to preload the mapped memory
if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
if (madvise(addr, file->size, MADV_WILLNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
}
@ -208,7 +207,7 @@ struct llama_mmap {
DWORD error = GetLastError();
if (hMapping == NULL) {
throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
throw format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str());
}
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
@ -216,7 +215,7 @@ struct llama_mmap {
CloseHandle(hMapping);
if (addr == NULL) {
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
throw format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str());
}
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
@ -244,9 +243,8 @@ struct llama_mmap {
#else
static constexpr bool SUPPORTED = false;
llama_mmap(struct llama_file *, bool prefetch = true) {
(void)prefetch;
throw std::runtime_error(std::string("mmap not supported"));
llama_mmap(struct llama_file *) {
throw std::string("mmap not supported");
}
#endif
};
@ -267,9 +265,9 @@ struct llama_mlock {
}
}
void init(void * ptr) {
LLAMA_ASSERT(addr == NULL && size == 0);
addr = ptr;
void init(void * addr) {
LLAMA_ASSERT(this->addr == NULL && this->size == 0);
this->addr = addr;
}
void grow_to(size_t target_size) {
@ -340,14 +338,14 @@ struct llama_mlock {
return (size_t) si.dwPageSize;
}
bool raw_lock(void * ptr, size_t len) {
bool raw_lock(void * addr, size_t size) {
for (int tries = 1; ; tries++) {
if (VirtualLock(ptr, len)) {
if (VirtualLock(addr, size)) {
return true;
}
if (tries == 2) {
fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
len, size, llama_format_win_err(GetLastError()).c_str());
size, this->size, llama_format_win_err(GetLastError()).c_str());
return false;
}
@ -363,7 +361,7 @@ struct llama_mlock {
// is equal to the number of pages in its minimum working set minus
// a small overhead."
// Hopefully a megabyte is enough overhead:
size_t increment = len + 1048576;
size_t increment = size + 1048576;
// The minimum must be <= the maximum, so we need to increase both:
min_ws_size += increment;
max_ws_size += increment;
@ -375,8 +373,8 @@ struct llama_mlock {
}
}
void raw_unlock(void * ptr, size_t len) {
if (!VirtualUnlock(ptr, len)) {
void raw_unlock(void * addr, size_t size) {
if (!VirtualUnlock(addr, size)) {
fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n",
llama_format_win_err(GetLastError()).c_str());
}
@ -384,16 +382,11 @@ struct llama_mlock {
#else
static constexpr bool SUPPORTED = false;
size_t lock_granularity() {
return (size_t) 65536;
}
bool raw_lock(const void * addr, size_t len) {
void raw_lock(const void * addr, size_t size) {
fprintf(stderr, "warning: mlock not supported on this system\n");
return false;
}
void raw_unlock(const void * addr, size_t len) {}
void raw_unlock(const void * addr, size_t size) {}
#endif
};
@ -402,70 +395,36 @@ struct llama_buffer {
uint8_t * addr = NULL;
size_t size = 0;
llama_buffer() = default;
void resize(size_t len) {
void resize(size_t size) {
delete[] addr;
addr = new uint8_t[len];
size = len;
addr = new uint8_t[size];
this->size = size;
}
~llama_buffer() {
delete[] addr;
}
// disable copy and move
llama_buffer(const llama_buffer&) = delete;
llama_buffer(llama_buffer&&) = delete;
llama_buffer& operator=(const llama_buffer&) = delete;
llama_buffer& operator=(llama_buffer&&) = delete;
};
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
struct llama_ctx_buffer {
uint8_t * addr = NULL;
bool is_cuda;
size_t size = 0;
llama_ctx_buffer() = default;
void resize(size_t size) {
free();
addr = (uint8_t *) ggml_cuda_host_malloc(size);
if (addr) {
is_cuda = true;
}
else {
// fall back to pageable memory
addr = new uint8_t[size];
is_cuda = false;
ggml_cuda_host_free(addr);
}
addr = (uint8_t *) ggml_cuda_host_malloc(size);
this->size = size;
}
void free() {
if (addr) {
if (is_cuda) {
ggml_cuda_host_free(addr);
}
else {
delete[] addr;
}
}
addr = NULL;
}
~llama_ctx_buffer() {
free();
if (addr) {
ggml_cuda_host_free(addr);
}
}
// disable copy and move
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
};
#else
typedef llama_buffer llama_ctx_buffer;

File diff suppressed because it is too large Load Diff

View File

@ -19,17 +19,11 @@
# define LLAMA_API
#endif
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_FILE_VERSION 3
#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT
#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 1
#define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 'ggjt'
#define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
#define LLAMA_SESSION_MAGIC 'ggsn'
#define LLAMA_SESSION_VERSION 0
#ifdef __cplusplus
extern "C" {
@ -46,9 +40,9 @@ extern "C" {
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
llama_token id; // token id
float logit; // log-odds of the token
float p; // probability of the token
} llama_token_data;
typedef struct llama_token_data_array {
@ -60,9 +54,9 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_gpu_layers; // number of layers to store in VRAM
int seed; // RNG seed, -1 for random
int n_ctx; // text context
int n_parts; // -1 for default
int seed; // RNG seed, 0 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
@ -79,16 +73,16 @@ extern "C" {
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
// LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
};
LLAMA_API struct llama_context_params llama_context_default_params();
@ -96,13 +90,6 @@ extern "C" {
LLAMA_API bool llama_mmap_supported();
LLAMA_API bool llama_mlock_supported();
// TODO: not great API - very likely to change
// Initialize the llama + ggml backend
// Call once at the start of the program
LLAMA_API void llama_init_backend();
LLAMA_API int64_t llama_time_us();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
@ -135,28 +122,26 @@ extern "C" {
int n_threads);
// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the maximum size in bytes of the state (rng, logits, embedding
// and kv_cache) - will often be smaller after compacting tokens
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
// Set the state reading from the specified address
// Returns the number of bytes read
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src);
LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
// Save/load session file
LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
@ -180,9 +165,9 @@ extern "C" {
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
LLAMA_API int llama_n_embd (struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
@ -196,7 +181,7 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
@ -206,25 +191,25 @@ extern "C" {
// Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float penalty);
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep = 1);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep = 1);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep = 1);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.

View File

@ -1 +0,0 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

View File

@ -1,12 +0,0 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
# voice options are David or Zira
[Parameter(Mandatory=$true)][string]$voice,
[Parameter(Mandatory=$true)][string]$text
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$speak.SelectVoice("Microsoft $voice Desktop");
$speak.Rate="0";
$speak.Speak($text);

View File

@ -13,11 +13,8 @@
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with limited number of characters. To increase this limit register to https://beta.elevenlabs.io to get an api key and paste it after 'ELEVEN_API_KEY='
#Keep the line commented to use the free version whitout api key
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk-llama/eleven-labs.py
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1

View File

@ -33,6 +33,8 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_parts_llama = -1;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
@ -47,7 +49,7 @@ struct whisper_params {
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk-llama/speak";
std::string speak = "./examples/talk-llama/speak.sh";
std::string prompt = "";
std::string fname_out;
std::string path_session = ""; // path to file for saving/loading model eval state
@ -70,6 +72,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
@ -120,6 +123,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
@ -235,14 +239,13 @@ int main(int argc, char ** argv) {
// llama init
llama_init_backend();
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 2048;
lparams.seed = 1;
lparams.f16_kv = true;
lparams.n_parts = params.n_parts_llama;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
@ -557,7 +560,7 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
if (!path_session.empty() && need_to_save_session) {
need_to_save_session = false;
llama_save_session_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
}
llama_token id = 0;
@ -606,8 +609,8 @@ int main(int argc, char ** argv) {
id = llama_sample_token_greedy(ctx_llama, &candidates_p);
} else {
// Temperature sampling
llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
llama_sample_top_k(ctx_llama, &candidates_p, top_k);
llama_sample_top_p(ctx_llama, &candidates_p, top_p);
llama_sample_temperature(ctx_llama, &candidates_p, temp);
id = llama_sample_token(ctx_llama, &candidates_p);
}

View File

@ -37,5 +37,5 @@ wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.c
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or `espeak` or Windows SpeechSynthesizer, but you can use whatever you wish.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use `espeak`, but you can use whatever you wish.

View File

@ -1,20 +1,23 @@
import sys
import importlib.util
api_key = "" #Write your https://beta.elevenlabs.io api key here
if not api_key:
print("To use elevenlabs you have to register to https://beta.elevenlabs.io and add your elevenlabs api key to examples/talk/eleven-labs.py")
sys.exit()
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import generate, play, save
from elevenlabs import ElevenLabs
eleven = ElevenLabs(api_key)
# Get a Voice object, by name or UUID
voice = "Arnold" #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
voice = eleven.voices["Arnold"] #Possible Voices: Adam Antoni Arnold Bella Domi Elli Josh
# Generate the TTS
audio = generate(
text=str(sys.argv[2:]),
voice=voice
)
audio = voice.generate(str(sys.argv[2:]))
# Save the TTS to a file
save(audio, "audio.mp3")
audio.save("audio")

View File

@ -1 +0,0 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

View File

@ -1,12 +0,0 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
# voice options are David or Zira
[Parameter(Mandatory=$true)][string]$voice,
[Parameter(Mandatory=$true)][string]$text
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$speak.SelectVoice("Microsoft $voice Desktop");
$speak.Rate="0";
$speak.Speak($text);

5
examples/talk/speak → examples/talk/speak.sh Normal file → Executable file
View File

@ -13,11 +13,8 @@
say "$2"
# Eleven Labs
# To use it, install the elevenlabs module from pip (pip install elevenlabs)
# It's possible to use the API for free with limited number of characters. To increase this limit register to https://beta.elevenlabs.io to get an api key and paste it after 'ELEVEN_API_KEY='
#Keep the line commented to use the free version without api key
# To use it, install the elevenlabs module from pip (pip install elevenlabs), register to https://beta.elevenlabs.io to get an api key and paste it in /examples/talk/eleven-labs.py
#
#export ELEVEN_API_KEY=your_api_key
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2"

View File

@ -36,7 +36,7 @@ struct whisper_params {
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak";
std::string speak = "./examples/talk/speak.sh";
std::string fname_out;
};

View File

@ -18,9 +18,6 @@ android {
vectorDrawables {
useSupportLibrary true
}
ndk {
abiFilters 'arm64-v8a', 'armeabi-v7a', 'x86', 'x86_64'
}
}
buildTypes {
@ -45,8 +42,8 @@ android {
}
ndkVersion "25.1.8937393"
externalNativeBuild {
cmake {
path = file("src/main/jni/whisper/CMakeLists.txt")
ndkBuild {
path 'src/main/jni/whisper/Android.mk'
}
}
packagingOptions {

View File

@ -10,16 +10,12 @@ fun decodeWaveFile(file: File): FloatArray {
file.inputStream().use { it.copyTo(baos) }
val buffer = ByteBuffer.wrap(baos.toByteArray())
buffer.order(ByteOrder.LITTLE_ENDIAN)
val channel = buffer.getShort(22).toInt()
buffer.position(44)
val shortBuffer = buffer.asShortBuffer()
val shortArray = ShortArray(shortBuffer.limit())
shortBuffer.get(shortArray)
return FloatArray(shortArray.size / channel) { index ->
when (channel) {
1 -> (shortArray[index] / 32767.0f).coerceIn(-1f..1f)
else -> ((shortArray[2*index] + shortArray[2*index + 1])/ 32767.0f / 2.0f).coerceIn(-1f..1f)
}
return FloatArray(shortArray.size) { index ->
(shortArray[index] / 32767.0f).coerceIn(-1f..1f)
}
}
@ -77,4 +73,4 @@ private fun headerBytes(totalLength: Int): ByteArray {
it.get(bytes)
return bytes
}
}
}

View File

@ -0,0 +1,26 @@
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper
include $(LOCAL_PATH)/Whisper.mk
include $(BUILD_SHARED_LIBRARY)
ifeq ($(TARGET_ARCH_ABI),armeabi-v7a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_vfpv4
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -mfpu=neon-vfpv4
include $(BUILD_SHARED_LIBRARY)
endif
ifeq ($(TARGET_ARCH_ABI),arm64-v8a)
include $(CLEAR_VARS)
LOCAL_MODULE := libwhisper_v8fp16_va
include $(LOCAL_PATH)/Whisper.mk
# Allow building NEON FMA code.
# https://android.googlesource.com/platform/ndk/+/master/sources/android/cpufeatures/cpu-features.h
LOCAL_CFLAGS += -march=armv8.2-a+fp16
include $(BUILD_SHARED_LIBRARY)
endif

View File

@ -0,0 +1 @@
APP_STL := c++_static

View File

@ -1,53 +0,0 @@
cmake_minimum_required(VERSION 3.10)
project(whisper.cpp)
set(CMAKE_CXX_STANDARD 11)
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
set(
SOURCE_FILES
${WHISPER_LIB_DIR}/ggml.c
${WHISPER_LIB_DIR}/whisper.cpp
${CMAKE_SOURCE_DIR}/jni.c
)
find_library(LOG_LIB log)
function(build_library target_name)
add_library(
${target_name}
SHARED
${SOURCE_FILES}
)
target_link_libraries(${target_name} ${LOG_LIB} android)
if (${target_name} STREQUAL "whisper_v8fp16_va")
target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16)
elseif (${target_name} STREQUAL "whisper_vfpv4")
target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4)
endif ()
if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug")
target_compile_options(${target_name} PRIVATE -O3)
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
target_link_options(${target_name} PRIVATE -flto)
endif ()
endfunction()
build_library("whisper") # Default target
if (${ANDROID_ABI} STREQUAL "arm64-v8a")
build_library("whisper_v8fp16_va")
elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a")
build_library("whisper_vfpv4")
endif ()
include_directories(${WHISPER_LIB_DIR})

View File

@ -0,0 +1,18 @@
WHISPER_LIB_DIR := $(LOCAL_PATH)/../../../../../../../
LOCAL_LDLIBS := -landroid -llog
# Make the final output library smaller by only keeping the symbols referenced from the app.
ifneq ($(APP_OPTIM),debug)
LOCAL_CFLAGS += -O3
LOCAL_CFLAGS += -fvisibility=hidden -fvisibility-inlines-hidden
LOCAL_CFLAGS += -ffunction-sections -fdata-sections
LOCAL_LDFLAGS += -Wl,--gc-sections
LOCAL_LDFLAGS += -Wl,--exclude-libs,ALL
LOCAL_LDFLAGS += -flto
endif
LOCAL_CFLAGS += -DSTDC_HEADERS -std=c11 -I $(WHISPER_LIB_DIR)
LOCAL_CPPFLAGS += -std=c++11
LOCAL_SRC_FILES := $(WHISPER_LIB_DIR)/ggml.c \
$(WHISPER_LIB_DIR)/whisper.cpp \
$(LOCAL_PATH)/jni.c

View File

@ -32,7 +32,7 @@ model="base.en"
# export the path to the whisper.cpp repo in the WHISPER_CPP_HOME env variable
# https://github.com/ggerganov/whisper.cpp
cd "${WHISPER_CPP_HOME}"
cd ${WHISPER_CPP_HOME}
if [ ! -f ./stream ] ; then
echo "whisper.nvim: the 'stream' executable was not found! WHISPER_CPP_HOME=${WHISPER_CPP_HOME}" > /tmp/whisper.nvim

View File

@ -14,24 +14,15 @@ https://user-images.githubusercontent.com/1991296/204126266-ce4177c6-6eca-4bd9-b
```java
git clone https://github.com/ggerganov/whisper.cpp
open whisper.cpp/examples/whisper.objc/whisper.objc.xcodeproj/
// If you don't want to convert a Core ML model, you can skip this step by create dummy model
mkdir models/ggml-base.en-encoder.mlmodelc
```
Make sure to build the project in `Release`:
<img width="947" alt="image" src="https://user-images.githubusercontent.com/1991296/197382607-9e1e6d1b-79fa-496f-9d16-b71dc1535701.png">
Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag for `ggml.c` in Build Phases.
Also, don't forget to add the `-DGGML_USE_ACCELERATE` compiler flag in Build Phases.
This can significantly improve the performance of the transcription:
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
<img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.

View File

@ -14,13 +14,9 @@
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; };
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
@ -41,13 +37,6 @@
18627C9529052C5800BD2A04 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = "<group>"; };
18627C9729052C6600BD2A04 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = "<group>"; };
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = "ggml-base.en.bin"; path = "../../../models/ggml-base.en.bin"; sourceTree = "<group>"; };
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
7FE342482A0C3FA20015A058 /* whisper-decoder-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-decoder-impl.h"; sourceTree = "<group>"; };
7FE342492A0C3FA20015A058 /* whisper-encoder-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder-impl.h"; sourceTree = "<group>"; };
7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-decoder-impl.m"; sourceTree = "<group>"; };
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = "ggml-base.en-encoder.mlmodelc"; path = "../../../models/ggml-base.en-encoder.mlmodelc"; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
@ -80,8 +69,6 @@
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
isa = PBXGroup;
children = (
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
7FE342442A0C3FA20015A058 /* coreml */,
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
18627C9729052C6600BD2A04 /* ggml.h */,
18627C9529052C5800BD2A04 /* ggml.c */,
@ -102,20 +89,6 @@
path = whisper.objc;
sourceTree = "<group>";
};
7FE342442A0C3FA20015A058 /* coreml */ = {
isa = PBXGroup;
children = (
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */,
7FE342462A0C3FA20015A058 /* whisper-encoder.h */,
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */,
7FE342482A0C3FA20015A058 /* whisper-decoder-impl.h */,
7FE342492A0C3FA20015A058 /* whisper-encoder-impl.h */,
7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */,
);
name = coreml;
path = ../../../coreml;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
@ -174,7 +147,6 @@
buildActionMask = 2147483647;
files = (
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */,
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */,
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */,
18627C8429052BDF00BD2A04 /* Main.storyboard in Resources */,
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */,
@ -189,14 +161,11 @@
buildActionMask = 2147483647;
files = (
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */,
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */,
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */,
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
18627C8C29052BE000BD2A04 /* main.m in Sources */,
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};

View File

@ -10,48 +10,36 @@ fi
qtype0="q5_0"
qtype1="q5_1"
upload="$1"
declare -a filedex
cd `dirname $0`
cd ../
# Let's loop across all the objects in the 'models' dir:
for i in ./models/*; do
# Check to see if it's a file or directory
if [ -d "$i" ]; then
# It's a directory! We should make sure it's not empty first:
if [ "$(ls -A $i)" ]; then
# Passed! Let's go searching for bin files (shouldn't need to go more than a layer deep here)
for f in "$i"/*.bin; do
# [Neuron Activation]
newfile=`echo "${f##*/}" | cut -d _ -f 1`;
if [ "$newfile" != "q5" ]; then
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" ${qtype1};
./quantize "${f}" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}/${i:9:${#i}-4}-${qtype1}.bin" "${i:-4}/${i:9:${#i}-4}-${qtype0}.bin" )
fi
done
fi
else
# It's a file! Let's make sure it's the right type:
if [ "${i##*.}" == "bin" ]; then
# And we probably want to skip the testing files
if [ "${i:9:8}" != "for-test" ]; then
# [Neuron Activation]
./quantize "${i}" "${i:-4}-${qtype1}.bin" ${qtype1};
./quantize "${i}" "${i:-4}-${qtype0}.bin" ${qtype0};
filedex+=( "${i:-4}-${qtype1}.bin" "${i:-4}-${qtype0}.bin" )
fi
fi
fi
done
./quantize ./models/ggml-tiny.en.bin ./models/ggml-tiny.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-tiny.bin ./models/ggml-tiny-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-base.en.bin ./models/ggml-base.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-base.bin ./models/ggml-base-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-small.en.bin ./models/ggml-small.en-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-small.bin ./models/ggml-small-${qtype1}.bin ${qtype1}
./quantize ./models/ggml-medium.en.bin ./models/ggml-medium.en-${qtype0}.bin ${qtype0}
./quantize ./models/ggml-medium.bin ./models/ggml-medium-${qtype0}.bin ${qtype0}
./quantize ./models/ggml-large.bin ./models/ggml-large-${qtype0}.bin ${qtype0}
if [ "$upload" == "1" ]; then
for i in ${!filedex[@]}; do
if [ "${filedex[$i]:9:8}" != "for-test" ]; then
scp ${filedex[$i]} root@linode0:/mnt/Data/ggml/ggml-model-${filedex[$i]:9}
fi
done
scp ./models/ggml-tiny.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-tiny.en-${qtype1}.bin
scp ./models/ggml-tiny-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-tiny-${qtype1}.bin
scp ./models/ggml-base.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-base.en-${qtype1}.bin
scp ./models/ggml-base-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-base-${qtype1}.bin
scp ./models/ggml-small.en-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-small.en-${qtype1}.bin
scp ./models/ggml-small-${qtype1}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-small-${qtype1}.bin
scp ./models/ggml-medium.en-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-medium.en-${qtype0}.bin
scp ./models/ggml-medium-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-medium-${qtype0}.bin
scp ./models/ggml-large-${qtype0}.bin root@linode0:/mnt/Data/ggml/ggml-model-whisper-large-${qtype0}.bin
fi

View File

@ -4,18 +4,9 @@ cp -rpv ../ggml/src/ggml.c ./ggml.c
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
cp -rpv ../ggml/src/ggml-opencl.c ./ggml-opencl.c
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
cp -rpv ../ggml/examples/common.h ./examples/common.h
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
cp -rpv ../ggml/examples/whisper/main.cpp ./examples/main/main.cpp
cp -rpv ../ggml/examples/whisper/quantize.cpp ./examples/quantize/quantize.cpp

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,11 @@
#pragma once
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_CUDA_MAX_DEVICES 16
void ggml_init_cublas(void);
void ggml_cuda_set_tensor_split(const float * tensor_split);
void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@ -20,17 +14,6 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
void ggml_cuda_free_data(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device);
void ggml_cuda_set_scratch_size(size_t scratch_size);
void ggml_cuda_free_scratch(void);
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif

View File

@ -1,67 +0,0 @@
// An interface allowing to compute ggml_cgraph with Metal
//
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
//
// How it works?
//
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
//
// You only need to make sure that all memory buffers that you used during the graph creation
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
// used during the graph evaluation to determine the arguments of the compute kernels.
//
// Synchronization between device and host memory (for example for input and output tensors)
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
//
#pragma once
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
struct ggml_tensor;
struct ggml_cgraph;
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_metal_context;
struct ggml_metal_context * ggml_metal_init(void);
void ggml_metal_free(struct ggml_metal_context * ctx);
// creates a mapping between a host memory buffer and a device memory buffer
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels
// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
// - max_size specifies the maximum size of a tensor and is used to create shared views such
// that it is guaranteed that the tensor will fit in at least one of the views
//
bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx,
const char * name,
void * data,
size_t size,
size_t max_size);
// set data from host memory into the device
void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
#ifdef __cplusplus
}
#endif

View File

@ -1,980 +0,0 @@
#import "ggml-metal.h"
#import "ggml.h"
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#ifdef GGML_METAL_NDEBUG
#define metal_printf(...)
#else
#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
#endif
#define UNUSED(x) (void)(x)
struct ggml_metal_buffer {
const char * name;
void * data;
size_t size;
id<MTLBuffer> metal;
};
struct ggml_metal_context {
float * logits;
id<MTLDevice> device;
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
// custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
id<MTLFunction> function_##name; \
id<MTLComputePipelineState> pipeline_##name
GGML_METAL_DECL_KERNEL(add);
GGML_METAL_DECL_KERNEL(mul);
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
GGML_METAL_DECL_KERNEL(scale);
GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
#undef GGML_METAL_DECL_KERNEL
};
// MSL code
// TODO: move the contents here when ready
// for now it is easier to work in a separate file
static NSString * const msl_library_source = @"see metal.metal";
// Here to assist with NSBundle Path Hack
@interface GGMLMetalClass : NSObject
@end
@implementation GGMLMetalClass
@end
struct ggml_metal_context * ggml_metal_init(void) {
fprintf(stderr, "%s: allocating\n", __func__);
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
fprintf(stderr, "%s: using MPS\n", __func__);
} else {
fprintf(stderr, "%s: not using MPS\n", __func__);
GGML_ASSERT(false && "MPS not supported");
}
#if 0
// compile from source string and show compile log
{
NSError * error = nil;
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
}
}
#else
UNUSED(msl_library_source);
// read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
{
NSError * error = nil;
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
}
#ifdef GGML_QKK_64
MTLCompileOptions* options = [MTLCompileOptions new];
options.preprocessorMacros = @{ @"QK_K" : @(64) };
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
#else
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
#endif
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
}
}
#endif
// load kernels
{
#define GGML_METAL_ADD_KERNEL(name) \
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
GGML_METAL_ADD_KERNEL(add);
GGML_METAL_ADD_KERNEL(mul);
GGML_METAL_ADD_KERNEL(mul_row);
GGML_METAL_ADD_KERNEL(scale);
GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
#undef GGML_METAL_ADD_KERNEL
}
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
if (ctx->device.maxTransferRate != 0) {
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
} else {
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
}
return ctx;
}
void ggml_metal_free(struct ggml_metal_context * ctx) {
fprintf(stderr, "%s: deallocating\n", __func__);
for (int i = 0; i < ctx->n_buffers; ++i) {
[ctx->buffers[i].metal release];
}
free(ctx);
}
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
//
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
const int64_t tsize = ggml_nbytes(t);
// find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs;
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
return ctx->buffers[i].metal;
}
}
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
return nil;
}
bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx,
const char * name,
void * data,
size_t size,
size_t max_size) {
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
fprintf(stderr, "%s: too many buffers\n", __func__);
return false;
}
if (data) {
// verify that the buffer does not overlap with any of the existing buffers
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
return false;
}
}
const size_t size_page = getpagesize();
size_t size_aligned = size;
if ((size_aligned % size_page) != 0) {
size_aligned += (size_page - (size_aligned % size_page));
}
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= ctx->device.maxBufferLength) {
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = data;
ctx->buffers[ctx->n_buffers].size = size;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
++ctx->n_buffers;
} else {
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
// one of the views
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
const size_t size_view = ctx->device.maxBufferLength;
for (size_t i = 0; i < size; i += size_step) {
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
if (i + size_step < size) {
fprintf(stderr, "\n");
}
++ctx->n_buffers;
}
}
fprintf(stderr, ", (%8.2f / %8.2f)",
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
} else {
fprintf(stderr, "\n");
}
}
return true;
}
void ggml_metal_set_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
}
void ggml_metal_get_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}
void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__);
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
const int n_cb = gf->n_threads;
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
for (int i = 0; i < n_cb; ++i) {
command_buffers[i] = [ctx->queue commandBuffer];
// enqueue the command buffers in order to specify their execution order
[command_buffers[i] enqueue];
}
// TODO: is this the best way to start threads?
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
dispatch_async(queue, ^{
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = nil;
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
for (int i = node_start; i < node_end; ++i) {
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src0;
struct ggml_tensor * src1 = gf->nodes[i]->src1;
struct ggml_tensor * dst = gf->nodes[i];
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
const int64_t ne03 = src0 ? src0->ne[3] : 0;
const uint64_t nb00 = src0 ? src0->nb[0] : 0;
const uint64_t nb01 = src0 ? src0->nb[1] : 0;
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
const int64_t ne10 = src1 ? src1->ne[0] : 0;
const int64_t ne11 = src1 ? src1->ne[1] : 0;
const int64_t ne12 = src1 ? src1->ne[2] : 0;
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
const int64_t ne0 = dst ? dst->ne[0] : 0;
const int64_t ne1 = dst ? dst->ne[1] : 0;
const int64_t ne2 = dst ? dst->ne[2] : 0;
const int64_t ne3 = dst ? dst->ne[3] : 0;
const uint64_t nb0 = dst ? dst->nb[0] : 0;
const uint64_t nb1 = dst ? dst->nb[1] : 0;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
//metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
//if (src0) {
// metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
// ggml_is_contiguous(src0), src0->name);
//}
//if (src1) {
// metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name);
//}
//if (dst) {
// metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name);
//}
switch (dst->op) {
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
// noop
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_add];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_mul];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SILU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int n_past = ((int32_t *)(src1->data))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_MUL_MAT:
{
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12);
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
}
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
// for F32 x F32 we use MPS
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
// we need to do ne02 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne02; ++i02) {
size_t offs_src0_cur = offs_src0 + i02*nb02;
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
int nth0 = 32;
int nth1 = 1;
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F16:
{
GGML_ASSERT(ne02 == ne12);
nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
{
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
GGML_ASSERT(false && "not implemented");
}
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_Q3_K ||
src0t == GGML_TYPE_Q4_K ||
src0t == GGML_TYPE_Q5_K ||
src0t == GGML_TYPE_Q6_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float eps = 1e-6f;
const int nth = 256;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float eps = 1e-5f;
const int nth = 256;
[encoder setComputePipelineState:ctx->pipeline_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ALIBI:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
const int n_head = ((int32_t *) src1->data)[1];
const float max_bias = ((float *) src1->data)[2];
if (__builtin_popcount(n_head) != 1) {
GGML_ASSERT(false && "only power-of-two n_head implemented");
}
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
const int nth = 32;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_past = ((int32_t *)(src1->data))[0];
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_CPY:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int nth = 32;
switch (src0t) {
case GGML_TYPE_F32:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
case GGML_TYPE_F16:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
}
[command_buffer commit];
});
}
// wait for all threads to finish
dispatch_barrier_sync(queue, ^{});
[command_buffers[n_cb - 1] waitUntilCompleted];
// check status of command buffers
// needed to detect if the device ran out-of-memory for example (#1881)
for (int i = 0; i < n_cb; i++) {
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false);
}
}
}

File diff suppressed because it is too large Load Diff

398
ggml-opencl.c Normal file
View File

@ -0,0 +1,398 @@
#include "ggml-opencl.h"
#define CL_TARGET_OPENCL_VERSION 110
#include <clblast_c.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "ggml.h"
#define MULTILINE_QUOTE(...) #__VA_ARGS__
const char * clblast_dequant = MULTILINE_QUOTE(
struct block_q4_0
{
float d;
uchar qs[16];
};
__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) {
const uint i = get_global_id(0) / 32;
const uint l = get_local_id(0);
const float d = blocks[i].d;
const uchar vi = blocks[i].qs[l];
const uint index = i*32 + l*2;
result[index + 0] = ((vi & 0xf) - 8)*d;
result[index + 1] = ((vi >> 4) - 8)*d;
}
struct block_q4_1
{
float d;
float m;
uchar qs[16];
};
__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
const uint i = get_global_id(0) / 32;
const uint l = get_local_id(0);
const float d = blocks[i].d;
const float m = blocks[i].m;
const uchar vi = blocks[i].qs[l];
const uint index = i*32 + l*2;
result[index + 0] = (vi & 0xf) * d + m;
result[index + 1] = (vi >> 4) * d + m;
}
struct block_q4_2
{
ushort d;
uchar qs[8];
};
__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) {
const uint i = get_global_id(0) / 16;
const uint l = get_local_id(0);
const float d = vload_half(0, (__global half*) &blocks[i].d);
const uchar vi = blocks[i].qs[l];
const uint index = i*16 + l*2;
result[index + 0] = ((vi & 0xf) - 8)*d;
result[index + 1] = ((vi >> 4) - 8)*d;
}
struct block_q5_0
{
float d;
uint qh;
uchar qs[16];
};
__kernel void dequantize_row_q5_0(__global struct block_q5_0* blocks, __global float* result) {
const uint i = get_global_id(0) / 32;
const uint l = get_local_id(0);
const float d = blocks[i].d;
const uchar vi = blocks[i].qs[l];
const uint l2 = l * 2;
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
const uint index = i*32 + l2;
result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
}
struct block_q5_1
{
ushort d;
ushort m;
uint qh;
uchar qs[16];
};
__kernel void dequantize_row_q5_1(__global struct block_q5_1* blocks, __global float* result) {
const uint i = get_global_id(0) / 32;
const uint l = get_local_id(0);
const float d = vload_half(0, (__global half*) &blocks[i].d);
const float m = vload_half(0, (__global half*) &blocks[i].m);
const uchar vi = blocks[i].qs[l];
const uint l2 = l * 2;
const uchar vh0 = ((blocks[i].qh & (1 << (l2 + 0))) >> (l2 + 0)) << 4;
const uchar vh1 = ((blocks[i].qh & (1 << (l2 + 1))) >> (l2 + 1)) << 4;
const uint index = i*32 + l2;
result[index + 0] = ((vi & 0xf) | vh0)*d + m;
result[index + 1] = ((vi >> 4) | vh1)*d + m;
}
struct block_q8_0
{
float d;
char qs[32];
};
__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
const uint i = get_global_id(0) / 32;
const uint l = get_local_id(0);
result[i*32 + l] = blocks[i].qs[l] * blocks[i].d;
}
);
#define CL_CHECK(err, name) \
do { \
cl_int err_ = (err); \
if (err_ != CL_SUCCESS) { \
fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \
exit(1); \
} \
} while (0)
#define QK5_0 32
typedef struct {
ggml_fp16_t d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
typedef struct {
float d; // delta
uint32_t qh; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} cl_block_q5_0;
static cl_platform_id platform;
static cl_device_id device;
static cl_context context;
static cl_command_queue queue;
static cl_program program;
static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q5_0, kernel_q5_1, kernel_q8_0;
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) {
cl_program p;
char *program_log;
size_t program_size, log_size;
int err;
program_size = strlen(program_buffer);
p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
if(err < 0) {
fprintf(stderr, "OpenCL error creating program");
exit(1);
}
err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL);
if(err < 0) {
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
program_log = (char*) malloc(log_size + 1);
program_log[log_size] = '\0';
clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
printf("%s\n", program_log);
free(program_log);
exit(1);
}
return p;
}
void ggml_cl_init(void) {
cl_int err = 0;
char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM");
char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE");
int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM));
int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE));
printf("\nInitializing CLBlast (First Run)...");
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
cl_uint num_platforms;
clGetPlatformIDs(0, NULL, &num_platforms);
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
clGetPlatformIDs(num_platforms, platforms, NULL);
platform = platforms[plat_num];
char platform_buffer[1024];
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
cl_uint num_devices;
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
device = devices[dev_num];
char device_buffer[1024];
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
CL_CHECK(err, "clCreateContext");
queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err);
CL_CHECK(err, "clCreateCommandQueue");
free(platforms);
free(devices);
program = build_program_from_source(context, device, clblast_dequant);
// Prepare dequantize kernels
kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err);
CL_CHECK(err, "clCreateKernel");
kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
CL_CHECK(err, "clCreateKernel");
kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
CL_CHECK(err, "clCreateKernel");
kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
CL_CHECK(err, "clCreateKernel");
kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
CL_CHECK(err, "clCreateKernel");
kernel_q8_0 = clCreateKernel(program, "dequantize_row_q8_0", &err);
CL_CHECK(err, "clCreateKernel");
}
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
if (req_size <= *cur_size) {
return;
}
// Reallocate buffer with enough space
if (*cur_size > 0) {
clReleaseMemObject(*buf);
}
cl_int err;
*buf = clCreateBuffer(context, flags, req_size, NULL, &err);
*cur_size = req_size;
CL_CHECK(err, "clCreateBuffer");
}
void ggml_cl_sgemm_wrapper(
const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b,
const int m, const int n, const int k,
const float alpha, const void *host_a, const int lda,
const float *host_b, const int ldb, const float beta,
float *host_c, const int ldc, const int btype) {
cl_int err = 0;
cl_kernel kernel;
size_t global = n * k, local, size_qb;
bool dequant;
cl_block_q5_0* cl_host_b;
switch (btype) {
case GGML_TYPE_F32:
dequant = false;
break;
case GGML_TYPE_Q4_0:
dequant = true;
kernel = kernel_q4_0;
local = 16;
size_qb = global * (sizeof(float) + local) / 32;
break;
case GGML_TYPE_Q4_1:
dequant = true;
kernel = kernel_q4_1;
local = 16;
size_qb = global * (sizeof(float) * 2 + local) / 32;
break;
case GGML_TYPE_Q4_2:
dequant = true;
kernel = kernel_q4_2;
local = 8;
size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
break;
case GGML_TYPE_Q5_0:
dequant = true;
kernel = kernel_q5_0;
local = 16;
// For some reason OpenCL seems to be incapable of working with structs of size 22.
// 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
// TODO Find the reason, fix and remove workaround.
const block_q5_0* b = (const block_q5_0*) host_b;
cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
for (size_t i = 0; i < global / 32; i++) {
cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
}
host_b = (const float*) cl_host_b;
size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
break;
case GGML_TYPE_Q5_1:
dequant = true;
kernel = kernel_q5_1;
local = 16;
size_qb = global * (sizeof(ggml_fp16_t) * 2 + sizeof(uint32_t) + local) / 32;
break;
case GGML_TYPE_Q8_0:
dequant = true;
kernel = kernel_q8_0;
local = 32;
size_qb = global * (sizeof(float) + local) / 32;
break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
abort();
}
const size_t size_a = m * k * sizeof(float);
const size_t size_b = n * k * sizeof(float);
const size_t size_c = m * n * sizeof(float);
// Prepare buffers
ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a);
if (dequant) {
ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb);
}
ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b);
ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c);
cl_event ev_a, ev_qb, ev_b;
if (dequant) {
err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb);
err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b);
CL_CHECK(err, "clSetKernelArg");
err = clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb);
CL_CHECK(err, "clEnqueueWriteBuffer qb");
} else {
err = clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b);
CL_CHECK(err, "clEnqueueWriteBuffer b");
}
err = clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a);
CL_CHECK(err, "clEnqueueWriteBuffer a");
if (dequant) {
err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b);
CL_CHECK(err, "clEnqueueNDRangeKernel");
clReleaseEvent(ev_qb);
}
clWaitForEvents(1, &ev_a);
clWaitForEvents(1, &ev_b);
clReleaseEvent(ev_a);
clReleaseEvent(ev_b);
cl_event ev_sgemm;
CLBlastStatusCode status = CLBlastSgemm((CLBlastLayout)order,
(CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b,
m, n, k,
alpha,
cl_buffer_a, 0, lda,
cl_buffer_b, 0, ldb,
beta,
cl_buffer_c, 0, ldc,
&queue, &ev_sgemm);
if (status != CLBlastSuccess) {
fprintf(stderr, "Error: CLBlast SGEMM %d\n", status);
abort();
}
cl_event ev_c;
clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c);
// Wait for completion
clWaitForEvents(1, &ev_c);
clReleaseEvent(ev_sgemm);
clReleaseEvent(ev_c);
if (btype == GGML_TYPE_Q5_0) {
free((void*) cl_host_b);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,24 +1,23 @@
#pragma once
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
void ggml_cl_init(void);
void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
enum ggml_blas_order {
GGML_BLAS_ORDER_ROW_MAJOR = 101,
GGML_BLAS_ORDER_COLUMN_MAJOR = 102,
};
void * ggml_cl_host_malloc(size_t size);
void ggml_cl_host_free(void * ptr);
enum ggml_blas_op {
GGML_BLAS_OP_N = 111,
GGML_BLAS_OP_T = 112,
GGML_BLAS_OP_C = 113,
};
void ggml_cl_free_data(const struct ggml_tensor* tensor);
void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype);
#ifdef __cplusplus
}

11444
ggml.c

File diff suppressed because it is too large Load Diff

674
ggml.h

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,15 @@
## Whisper model files in custom ggml format
The [original Whisper PyTorch models provided by OpenAI](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L17-L27)
are converted to custom `ggml` format in order to be able to load them in C/C++.
Conversion is performed using the [convert-pt-to-ggml.py](convert-pt-to-ggml.py) script.
You can either obtain the original models and generate the `ggml` files yourself using the conversion script,
or you can use the [download-ggml-model.sh](download-ggml-model.sh) script to download the already converted models.
Currently, they are hosted on the following locations:
have been converted to custom `ggml` format in order to be able to load them in C/C++. The conversion has been performed
using the [convert-pt-to-ggml.py](convert-pt-to-ggml.py) script. You can either obtain the original models and generate
the `ggml` files yourself using the conversion script, or you can use the [download-ggml-model.sh](download-ggml-model.sh)
script to download the already converted models. Currently, they are hosted on the following locations:
- https://huggingface.co/ggerganov/whisper.cpp
- https://ggml.ggerganov.com
Sample download:
Sample usage:
```java
$ ./download-ggml-model.sh base.en
@ -23,16 +21,6 @@ You can now use it like this:
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
```
To convert the files yourself, use the convert-pt-to-ggml.py script. Here is an example usage.
The original PyTorch files are assumed to have been downloaded into ~/.cache/whisper
Change `~/path/to/repo/whisper/` to the location for your copy of the Whisper source:
```
mkdir models/whisper-medium
python models/convert-pt-to-ggml.py ~/.cache/whisper/medium.pt ~/path/to/repo/whisper/ ./models/whisper-medium
mv ./models/whisper-medium/ggml-model.bin models/ggml-medium.bin
rmdir models/whisper-medium
```
A third option to obtain the model files is to download them from Hugging Face:
https://huggingface.co/ggerganov/whisper.cpp/tree/main
@ -70,7 +58,7 @@ git clone https://github.com/openai/whisper
git clone https://github.com/ggerganov/whisper.cpp
# clone HF fine-tuned model (this is just an example)
git clone https://huggingface.co/openai/whisper-medium
git clone https://huggingface.co/openai/whisper-base.en
# convert the model to ggml
python3 ./whisper.cpp/models/convert-h5-to-ggml.py ./whisper-medium/ ./whisper .

Some files were not shown because too many files have changed in this diff Show More