Compare commits


91 Commits

Author SHA1 Message Date
9962371f71 release : v1.5.3 2024-01-03 19:36:33 +02:00
993acb5d41 swift : update Package.swift to use ggml as package dependency (#1701)
* updates Package.swift to use ggml as dependency

* cleans up the Package.swift file by removing redundant source files

* updates ggml url src to ggerganov
2024-01-03 19:30:26 +02:00
a3d0aa73d1 ggml : add error handling to graph_compute (#1714) 2024-01-03 15:39:43 +02:00
14c57952f7 cuda : simplify expression
Co-authored-by: slaren <slarengh@gmail.com>
2024-01-03 14:43:51 +02:00
6c369d6788 cuda : mark I16 and I32 ops as unsupported
ggml-ci
2024-01-03 14:43:51 +02:00
4cdd9aad9b metal : add kernel_get_rows_i32
ggml-ci
2024-01-03 14:43:51 +02:00
f38c057503 metal : optimize ggml_mul_mat_id (faster Mixtral PP) (llama/4725)
* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id
2024-01-03 14:43:51 +02:00
1e5544b39b metal : enable shader debugging (cmake option) (llama/4705)
* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

ggml-ci
2024-01-03 14:43:51 +02:00
d5673af79f ggml : add ggml_vdotq_s32 alias (llama/4715)
ggml-ci
2024-01-03 14:43:51 +02:00
a28dacec65 CUDA: fixed tensor cores not being used on RDNA3 (llama/4697) 2024-01-03 14:43:51 +02:00
dbe29d4e33 ggml : add ggml_cpu_has_avx_vnni() (llama/4589)
* feat: add avx_vnni based on intel documents

* ggml: add avx vnni based on intel document

* llama: add avx vnni information display

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* docs: add more details about using oneMKL and oneAPI for intel processors

* Update ggml.c

Fix indentation update

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-01-03 14:43:51 +02:00
fe3a67c546 CUDA: fix tensor core logic for Pascal and HIP (llama/4682) 2024-01-03 14:43:51 +02:00
b138ff2be3 cuda: fix vmm oom issue on NVIDIA AGX Orin (llama/4687)
Signed-off-by: hydai <hydai@secondstate.io>
2024-01-03 14:43:51 +02:00
cf6f1e4181 ggml : extend ggml_get_rows, ggml_repeat, ggml_concat (ggml/639)
* add more int ops

* ggml_compute_forward_dup_bytes

* add tests

* PR comments

* tests : minor indentations

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-01-03 14:43:51 +02:00
620a223814 scripts : fix sync order + metal sed 2024-01-03 14:43:51 +02:00
f39f9690ec examples : fix WASM Stack Overflow (#1713)
Fix for the following problem, which appeared when executing the WASM example with newer versions:

"""
RuntimeError: Aborted(Stack overflow! Stack cookie has been overwritten at 0x12be2b10, expected hex dwords 0x89BACDFE and 0x2135467, but received 0x00000000 0x00000000)
"""
2024-01-02 16:50:04 +00:00
f9ca90256b docker : fix the publishing of the CUDA Docker image (#1704) 2023-12-30 23:12:31 +02:00
2623640cd6 scripts : do not sync commits from this repo 2023-12-29 15:03:08 +02:00
d87de61ae6 ci : build with CLBlast + ggml-opencl use GGML_API (#1576)
* Build with CLBlast

* Declare GGML_API

After rebasing, examples/talk-llama failed:

"D:\a\whisper.cpp\whisper.cpp\build\ALL_BUILD.vcxproj" (build target) (1) ->
"D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj" (default target) (14) ->
(Link target) ->
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_free_data referenced in function "public: __cdecl llama_model::~llama_model(void)" (??1llama_model@@QEAA@XZ) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  llama.obj : error LNK2019: unresolved external symbol ggml_cl_transform_tensor referenced in function "public: void __cdecl llama_model_loader::load_all_data(struct ggml_context *,void (__cdecl*)(float,void *),void *,struct llama_mlock *)" (?load_all_data@llama_model_loader@@QEAAXPEAUggml_context@@P6AXMPEAX@Z1PEAUllama_mlock@@@Z) [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
  D:\a\whisper.cpp\whisper.cpp\build\bin\Release\talk-llama.exe : fatal error LNK1120: 2 unresolved externals [D:\a\whisper.cpp\whisper.cpp\build\examples\talk-llama\talk-llama.vcxproj]
2023-12-29 12:23:27 +02:00
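The fix described above is to mark the ggml-opencl entry points with GGML_API so they are exported and the talk-llama link step can resolve them. A minimal sketch, with parameter lists assumed from the call sites named in the linker errors (not quoted from the patch):

```cpp
// Sketch: exporting the ggml-opencl symbols with GGML_API; without it,
// the symbols are not visible to the talk-llama link step.
#include "ggml.h" // defines GGML_API and struct ggml_tensor

#ifdef __cplusplus
extern "C" {
#endif

// Assumed signatures, inferred from the unresolved-external errors above.
GGML_API void ggml_cl_free_data(const struct ggml_tensor * tensor);
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);

#ifdef __cplusplus
}
#endif
```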
f5f485f899 whisper : replace tensor->n_dims with ggml_n_dims(tensor) (#1694) 2023-12-29 11:38:35 +02:00
e77b27c331 sync : ggml (VMM, sync-ggml-am, dotprod ARM fixes, CUDA fixes) (#1691)
* scripts : add sync-ggml-am.sh

* sync : ggml (VMM, ARM dot prod fix, etc.)

* build : fix CUDA build

* ggml : fix some mul mat cases + add tests for src1 F16

dbd02958fa
2023-12-29 11:30:47 +02:00
a5cc3dc8a2 download : fix large q5 model name (#1695)
fixed typo in large-v3-q5-0 model name to match HF link
2023-12-29 11:14:32 +02:00
37a709f655 whisper : Replace WHISPER_PRINT_DEBUG with WHISPER_LOG_DEBUG (#1681) 2023-12-23 12:02:58 +00:00
3a5302108d sync : ggml (ggml_scale, ggml_row_size, etc.) (#1677)
* sync : ggml

* sync : llama.cpp

* talk-llama : fix obsolete param

* ggml-alloc : fix ggml_tallocr_is_own

* talk.wasm : update to new ggml

* ggml : fix type punning in ggml_scale

* ggml : cuda jetson + arm quants warnings
2023-12-22 17:53:39 +02:00
d2ee117a0a docker : Dockerize whisper.cpp (#1674)
* build: add dockerfile for ci

* ci: add action to build/push docker image

* fix: lowercase repository to fix ci

* ci: update cuBLAS flag

* build: install curl and ffmpeg in image

* docs: add docker section

* fix: improve args check when download model
2023-12-22 11:16:02 +00:00
db8ccdb850 CI : Add coverage for talk-llama when WHISPER_CUBLAS=1 (#1672) 2023-12-21 22:39:46 +00:00
d2419030b0 examples : Revert CMakeLists.txt for talk-llama (#1669) 2023-12-21 22:48:52 +02:00
8986690c2a cmake : set default CUDA architectures (#1667) 2023-12-21 15:44:04 +02:00
9286d3f584 bench.py : add different large models (#1655)
Add the different large v1, v2, v3 models to the benchmark.
2023-12-19 12:40:14 +02:00
940de9dbe9 wchess : update README.md 2023-12-14 22:00:47 +02:00
88112c8afb release : v1.5.2 2023-12-14 17:56:39 +02:00
375585c07c wchess : update readme 2023-12-14 17:51:14 +02:00
fd99ece8e3 wchess : whisper assisted chess (#1595)
* wchess: whisper assisted chess

* wchess: fix allowed moves in check

* wchess: touchstart, touchend events

* wchess: css, disabled button

* wchess : html touches

* wchess : minor fixes and code style

* wchess : bump encoder context to 1280

* wchess : index.html

* wchess : fix CI warnings

* wchess : add array header

* wchess : build static library

* wchess : display grammar

* wchess : update UX

* wchess : add comment

* wchess : add README

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-12-14 15:58:26 +02:00
8171e621fc sync : ggml (Metal fixes, new ops, tests) (#1633)
* sync : ggml (Metal fixes, new ops, tests)

* cuda : fix bin bcast when src1 and dst have different types
2023-12-13 21:55:03 +02:00
ec03661b20 cmake : target windows 8 or above for prefetchVirtualMemory in llama-talk (#1617)
Since we use PrefetchVirtualMemory, we specify that we target Windows 8 or above; otherwise compilers will refuse to use the PrefetchVirtualMemory API. (I understand it is loaded dynamically, but the header definition has this limitation.)
2023-12-12 11:35:00 +00:00
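For context, a minimal sketch of the constraint (illustrative, not the actual patch): the Windows headers only declare PrefetchVirtualMemory when `_WIN32_WINNT` is 0x0602 (Windows 8) or higher, so a translation unit referencing it must announce that target.

```cpp
// Illustrative only: PrefetchVirtualMemory is declared by memoryapi.h
// (pulled in via windows.h) only for _WIN32_WINNT >= 0x0602.
#define _WIN32_WINNT 0x0602
#include <windows.h>

static void prefetch_region(void * addr, size_t len) {
    WIN32_MEMORY_RANGE_ENTRY range;
    range.VirtualAddress = addr;
    range.NumberOfBytes  = len;
    // Hint the OS to page in the range ahead of first use.
    PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0);
}
```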
6335933a5b cmake : Fix bug in httplib.h for mingw (#1615)
Fix bug in httplib.h for MinGW; see https://github.com/yhirose/cpp-httplib/issues/1669
2023-12-10 17:47:52 +00:00
885b5563d0 metal : fix ggml_metal_log vargs (#1606) 2023-12-08 13:50:50 +02:00
9521ba6801 whisper.objc : disable timestamps for real-time transcription 2023-12-08 13:43:37 +02:00
29511d33c7 whisper : more debug messages + fix fallback logic 2023-12-08 13:43:12 +02:00
7bc4d22337 metal : fix soft_max kernel src1 argument (#1602) 2023-12-08 13:39:32 +02:00
afce6fa113 sync : ggml (new ops, new backend, etc) (#1602)
* sync : ggml (new ops, new backend, etc)

* whisper : remove obsolete broadcasting code

* ggml : remove backend self-registers + fix ggml_concat + n_task logic

* metal : fix assert

* metal : print resource path

* whisper : fix bug if metal init fails
2023-12-07 22:27:19 +02:00
3163090d89 server : pass max-len argument to the server (#1574)
This commit fixes the missing parameter binding for max-len between the input
arguments and wparams.
2023-12-05 23:01:45 +02:00
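The binding itself is a one-liner; a sketch of what was missing (names follow the whisper.cpp example conventions, where `params` holds the parsed CLI options and `wparams` is the `whisper_full_params` passed to inference; not quoted from the patch):

```cpp
// Forward the parsed CLI value into the inference parameters;
// this assignment was previously missing for max_len.
wparams.max_len = params.max_len;
```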
f0efd0202d ios : Remove #if arch(arm) check for using Metal (#1561) 2023-12-05 01:14:26 +00:00
3c28d1a571 ggml : Fix 32-bit compiler warning (#1575)
Warning about %lu on 32-bit targets. Updated to %zu.
2023-12-03 14:15:28 +00:00
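An illustration of the issue: on 32-bit targets `unsigned long` and `size_t` can differ in width, so `%zu` is the only portable conversion for `size_t`:

```cpp
#include <cstdio>
#include <cstddef>

int main() {
    size_t n = 42;
    // printf("%lu", n) warns (or misbehaves) where size_t != unsigned long;
    // %zu is defined for size_t on all conforming targets.
    std::printf("n = %zu\n", n);
    return 0;
}
```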
e369243ebd ggml : re-enable blas for src0 != F32 (#1583) 2023-12-01 23:57:52 +02:00
a0ec3fac54 Server : Add support for .vtt format to Whisper server (#1578)
- The code comes from examples/main
- The output mimetype is set to text/vtt

Example usage:
```shell
curl 127.0.0.1:8080/inference \
-H "Content-Type: multipart/form-data" \
-F file="@samples/jfk.wav" \
-F temperature="0.2" \
-F response-format="vtt"
```
2023-11-30 23:44:26 +00:00
6559b538e5 server : backport .srt output format (#1565)
This commit adds support for the .srt format to the Whisper server. The code is
effectively backported from examples/main. The output mimetype is set to
application/x-subrip as per https://en.wikipedia.org/wiki/SubRip.

Example usage:

  curl 127.0.0.1:8080/inference \
    -H "Content-Type: multipart/form-data" \
    -F file="@<file-path>" \
    -F temperature="0.2" \
    -F response-format="srt"
2023-11-28 15:42:58 +02:00
73d5005880 cmake : install required ggml.h header (#1568) 2023-11-28 15:41:49 +02:00
6b094b6dfe server : set default CORS headers to allow all (#1567) 2023-11-28 11:55:20 +02:00
641f2f4282 readme : update help (#1560) 2023-11-27 12:04:08 +02:00
bfacd9f8ce CI : Add CUDA 11.8.0 support (#1554)
* try to fix cublas build in CI

* add multiple cuda-toolkit version

* Update build.yml

* Disable CUDA-toolkit 10.2.89
2023-11-27 12:03:16 +02:00
f52e74d4dc CI : Rectify the Clang-Related workflow issues (#1551)
* fix bugs in workflow

* fix missing clang in workflow

* Update build.yml
2023-11-27 11:35:37 +02:00
23c21e92eb server : automatically convert audio on the server (#1539)
* server : automatically convert audio on the server

* server : remove redundant comments

* server : automatic conversion refactor

* server : update server readme

* server : remove unnecessary comments and tabs

* server : put back remove calling

* server : apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server : check ffmpeg before the server launch

* server : fix indentation

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server : fix typo in function call

* server : fix typo in function call

* server : add warning in readme

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-11-27 11:28:34 +02:00
447d49530c whisper : remove trailing whitespaces 2023-11-24 13:13:21 +02:00
9d6ebd877c release : v1.5.1 2023-11-24 12:41:55 +02:00
0ba365f958 metal : add backend function to check device family support (#1547) 2023-11-24 12:37:08 +02:00
010c8ec3ab cuda : sync some minor stuff from llama.cpp (#1548) 2023-11-24 12:36:21 +02:00
ffdb5c4735 whisper : fix typo 2023-11-24 09:45:10 +02:00
a5881d619c server : add --print-realtime param (#1541)
* server : add --print-realtime param

* Fix duplicate realtime output
2023-11-24 09:35:02 +02:00
34f70b3a56 whisper : add whisper_lang_str_full (#1546)
* Update whisper.h

add whisper_lang_fullstr to retrieve the full language name

* Update whisper.cpp

add whisper_lang_fullstr to return the full language name

* fullstr -> str_full

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-11-24 09:33:13 +02:00
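A small usage sketch of the new call, assuming it mirrors the existing whisper_lang_str (language id in, C string out):

```cpp
// Assumed usage: whisper_lang_id() maps a language code to an id; the
// new whisper_lang_str_full() returns the full name for that id,
// alongside the short code from whisper_lang_str().
#include <cstdio>
#include "whisper.h"

int main() {
    const int id = whisper_lang_id("en");
    std::printf("short: %s, full: %s\n", whisper_lang_str(id), whisper_lang_str_full(id));
    return 0;
}
```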
8328d1900f fix(server): typo in temperature parameter (#1545)
Also fixed another typo in comments.
2023-11-23 20:59:36 +02:00
d2bd5f0bdc metal : fix build (#1544) 2023-11-23 20:20:53 +02:00
34209a37a2 readme : add server example 2023-11-23 17:20:33 +02:00
180e062eda go : fixed Makefile for MacOS ARM 64 (#1530)
* Fixed Makefile for MacOS ARM 64 based on https://github.com/ggerganov/whisper.cpp/issues/1344 + proper ggml-metal env var setting

* conditional to fix broken non-macos compilation

* spaces -> tab

* make : fix whitespaces

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-11-22 18:08:11 +02:00
5c7be85fdc Change temp file name for server application (#1535)
Avoid issue of removing file if it exists in the current working
directory
2023-11-22 09:23:36 +01:00
146169ec38 bench : pass memcpy threads from cli 2023-11-21 22:27:22 +02:00
9befab5ab9 bench : multi-thread memcpy (#1534) 2023-11-21 22:07:30 +02:00
9ac88f2b57 Close file after writing in server application (#1533)
Fixes a mistake that left the file open while reading it again as WAV
2023-11-21 20:36:10 +01:00
46f5b6cb08 server : add video to readme 2023-11-21 17:30:43 +02:00
eff3570f78 server : add a REST Whisper server example with OAI-like API (#1380)
* Add first draft of server

* Added json support and base funcs for server.cpp

* Add more user input via api-request

also some clean up

* Add request params and load post function

Also some general clean up

* Remove unused function

* Add readme

* Add exception handlers

* Update examples/server/server.cpp

* make : add server target

* Add magic curl syntax

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-11-20 21:40:24 +02:00
fa19bc4195 whisper : update example in whisper.h (#1529)
Update the example in the header; the previous example was deprecated.
2023-11-20 20:52:27 +02:00
a01b2e0971 sdl : fix audio callback (#1523) 2023-11-20 13:16:38 +02:00
8159a9ab99 whisper : reuse whisper_decode_with_state (#1521) 2023-11-20 13:16:11 +02:00
7516d9c16d ci : redistribute CUDA DLLs (#1522)
see https://docs.nvidia.com/cuda/eula/index.html#attachment-a
2023-11-19 12:43:22 +02:00
46cc26d1b9 whisper : fix with_state methods to use the correct state (#1519)
Co-authored-by: Sandro Hanea <sandrohanea@microsoft.com>
2023-11-19 11:25:30 +02:00
f784f9fa12 whisper : fix overriding the audio context 2023-11-19 10:32:32 +02:00
ca23f8ee6d cuda : assert ggml_add sources to be contiguous 2023-11-19 10:32:08 +02:00
e2f0eba2d4 ios : sync submodule 2023-11-17 10:42:04 +02:00
d4353e48f7 sync : ggml (ggml-alloc + linker + gguf fixes) (#1501) 2023-11-17 10:00:07 +02:00
bebf0da983 quantize : add support for K-quant types 2023-11-16 16:18:24 +02:00
848e54f3ad bench : fix memcpy bench size 2023-11-16 10:59:32 +02:00
7883d1cae4 talk-llama : improve quote and backtick handling (#1364)
* ISSUE-1329: replace " with ' so it doesn't try to execute code in backticks.

* Typo

* Update to keep possessives in the output

Closes the ', puts a ' in double quotes, then reopens the ' to escape the ' characters.
2023-11-16 10:34:05 +02:00
ccc85b4ff8 talk-llama : enable GPU by default 2023-11-15 21:33:00 +02:00
c7606b47df models : add info about distilled models 2023-11-15 21:10:13 +02:00
d38af151a1 release : v1.5.0 2023-11-15 21:02:52 +02:00
94267df08e bench-all : add distil models 2023-11-15 20:49:12 +02:00
8713c67133 js : latest whisper.js 2023-11-15 20:10:16 +02:00
57a60639bb bench-all : indentations 2023-11-15 20:01:15 +02:00
bfbaa4dce5 whisper : make large version explicit + fix data size units (#1493) 2023-11-15 19:42:25 +02:00
1d79e78402 java : fix test (#1492) 2023-11-15 17:42:53 +02:00
b6c5f49b78 whisper : add batched decoding (#1486)
* whisper : add whisper_batch

* whisper : move kv_self to whisper_state

* whisper : full batched decoding support

* whisper : fix memory leak in whisper_batch

* whisper : fix mem leak again + remove obsolete function

* whisper : clear kv cache when using whisper_decode API

* whisper : speed-up sampling

* whisper : fix decoders initializer

* bench : add batch size 5 bench

* whisper : add comment about the KV cache size

* whisper : add check for max number of decoders

* whisper : avoid starting sampling threads with bs=1

* whisper : enable beam-search by default

* cuda : sync llama.cpp fixes
2023-11-15 16:12:52 +02:00
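Since beam-search decoding is now enabled by default, a caller sketch (parameter names as in whisper.h; the 5/5 defaults match the README and Java-test changes in the diffs below; `ctx` and `pcmf32` are assumed to be an initialized whisper context and a float PCM buffer, as in the examples):

```cpp
// With this change, the default strategy parameters moved to
// best_of = 5 and beam_size = 5 (previously 2 and -1).
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
    fprintf(stderr, "failed to process audio\n");
}
```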
107 changed files with 51240 additions and 4346 deletions

.devops/main-cuda.Dockerfile

@ -0,0 +1,38 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=12.3.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the CUDA runtime image
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
WORKDIR /app
# Unless otherwise specified, we make a fat build.
ARG CUDA_DOCKER_ARCH=all
# Set nvcc architecture
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV WHISPER_CUBLAS=1
RUN apt-get update && \
apt-get install -y build-essential \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# Ref: https://stackoverflow.com/a/53464012
ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY .. .
RUN make
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENTRYPOINT [ "bash", "-c" ]

.devops/main.Dockerfile

@ -0,0 +1,19 @@
FROM ubuntu:22.04 AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
RUN make
FROM ubuntu:22.04 AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app
ENTRYPOINT [ "bash", "-c" ]

.github/workflows/build.yml

@ -25,6 +25,7 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev
make
@ -86,6 +87,7 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }}
@ -113,8 +115,9 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake libsdl2-dev
apt install -y clang build-essential cmake libsdl2-dev
cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang
make
ctest -L gh --output-on-failure'
@ -140,6 +143,7 @@ jobs:
docker run --platform ${{ matrix.arch }} --rm \
-v ${{ github.workspace }}:/workspace \
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential cmake
cmake . -DCMAKE_BUILD_TYPE=Debug -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON
@ -162,7 +166,7 @@ jobs:
s2arc: x64
jnaPath: win32-x86-64
- sdl2: ON
s2ver: 2.26.0
s2ver: 2.28.5
steps:
- name: Clone
@ -217,13 +221,16 @@ jobs:
sdl2: [ON]
include:
- arch: Win32
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x86.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x86.zip
s2arc: x86
clblast: OFF
- arch: x64
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip
obzip: https://github.com/OpenMathLib/OpenBLAS/releases/download/v0.3.25/OpenBLAS-0.3.25-x64.zip
s2arc: x64
clblast: ON
clver: 1.6.1
- sdl2: ON
s2ver: 2.26.0
s2ver: 2.28.5
steps:
- name: Clone
@ -248,6 +255,18 @@ jobs:
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Install OpenCL
if: matrix.clblast == 'ON'
run: vcpkg.exe --triplet=${{ matrix.arch }}-windows install opencl
- name: Fetch CLBlast and set CLBlast_DIR
if: matrix.clblast == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO clblast.zip https://github.com/CNugteren/CLBlast/releases/download/${{ matrix.clver }}/CLBlast-${{ matrix.clver }}-windows-x64.zip
7z x clblast.zip
7z x CLBlast-${{ matrix.clver }}-windows-x64.7z
echo "CLBlast_DIR=$env:GITHUB_WORKSPACE/CLBlast-${{ matrix.clver }}-windows-x64/lib/cmake/CLBlast" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
@ -255,6 +274,7 @@ jobs:
-DWHISPER_OPENBLAS=${{ matrix.blas }}
-DCMAKE_LIBRARY_PATH="$env:OPENBLAS_PATH/lib"
-DWHISPER_SDL2=${{ matrix.sdl2 }}
-DWHISPER_CLBLAST=${{ matrix.clblast }}
- name: Build
run: |
@ -269,11 +289,15 @@ jobs:
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Copy clblast.dll
if: matrix.clblast == 'ON'
run: copy "$env:CLBlast_DIR/../../clblast.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.blas == 'ON' && matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-blas-bin-${{ matrix.arch }}
name: whisper-blas${{ matrix.clblast == 'ON' && '-clblast' || ''}}-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
windows-cublas:
@ -285,11 +309,12 @@ jobs:
arch: [x64]
cublas: [ON]
sdl2: [ON]
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
s2arc: x64
- sdl2: ON
s2ver: 2.26.0
s2ver: 2.28.5
steps:
- name: Clone
@ -300,7 +325,9 @@ jobs:
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.10
uses: Jimver/cuda-toolkit@v0.2.11
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
@ -313,12 +340,20 @@ jobs:
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=1
-DWHISPER_CUBLAS=${{ matrix.cublas }}
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build
- name: Build ${{ matrix.cuda-toolkit }}
run: |
cd ./build
msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }}
cmake --build . --config ${{ matrix.build }}
- name: Copy CUDA DLLs
run: >
Copy-Item -PassThru
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
-Include cudart64_*,cublas64_*,cublasLt64_*
-Destination build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
@ -328,7 +363,7 @@ jobs:
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v1
with:
name: whisper-cublas-bin-${{ matrix.arch }}
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
emscripten:

.github/workflows/docker.yml

@ -0,0 +1,57 @@
name: Publish Docker image
on:
pull_request:
push:
branches:
- master
jobs:
push_to_registry:
name: Push Docker image to Docker Hub
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
env:
COMMIT_SHA: ${{ github.sha }}
strategy:
matrix:
config:
- { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64,linux/arm64" }
- { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" }
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push Docker image (versioned)
if: github.event_name == 'push'
uses: docker/build-push-action@v5
with:
context: .
push: true
platforms: ${{ matrix.config.platforms }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
file: ${{ matrix.config.dockerfile }}
- name: Build and push Docker image (tagged)
uses: docker/build-push-action@v4
with:
context: .
push: ${{ github.event_name == 'push' }}
platforms: ${{ matrix.config.platforms }}
tags: "ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}"
file: ${{ matrix.config.dockerfile }}

.gitignore

@ -31,6 +31,7 @@ build-sanitize-thread/
/talk-llama
/bench
/quantize
/server
/lsp
arm_neon.h

CMakeLists.txt

@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.5)
project(whisper.cpp VERSION 1.4.3)
project(whisper.cpp VERSION 1.5.3)
# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@ -218,11 +218,17 @@ if (WHISPER_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS)
if (WHISPER_STATIC)
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
if (WIN32)
# As of 12.3.1, the CUDA Toolkit for Windows does not offer a static cublas library
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
else ()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
endif()
else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
else()
message(FATAL_ERROR "cuBLAS not found")
endif()
@ -338,8 +344,8 @@ else()
endif()
else()
if (EMSCRIPTEN)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -s TOTAL_STACK=5242880")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -s TOTAL_STACK=5242880")
else()
if(NOT WHISPER_NO_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx")
@ -521,7 +527,13 @@ endif()
if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
# Only configure ggml CUDA architectures if not globally set
if (NOT DEFINED GGML_CUDA_ARCHITECTURES)
# Not overridden by user, so set defaults
set(GGML_CUDA_ARCHITECTURES 52 61 70)
endif()
message(STATUS "GGML Configuring CUDA architectures ${GGML_CUDA_ARCHITECTURES}")
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES ${GGML_CUDA_ARCHITECTURES})
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
endif()
@ -533,7 +545,7 @@ target_compile_definitions(${TARGET} PUBLIC
${WHISPER_EXTRA_FLAGS}
)
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "ggml.h;whisper.h")
include(GNUInstallDirs)

Makefile

@ -1,4 +1,4 @@
default: main bench quantize
default: main bench quantize server
ifndef UNAME_S
UNAME_S := $(shell uname -s)
@ -206,7 +206,7 @@ ifdef WHISPER_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib
WHISPER_OBJ += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
@ -338,7 +338,7 @@ libwhisper.so: $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so $(WHISPER_OBJ) $(LDFLAGS)
clean:
rm -f *.o main stream command talk talk-llama bench quantize lsp libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench quantize server lsp libwhisper.a libwhisper.so
#
# Examples
@ -359,6 +359,9 @@ bench: examples/bench/bench.cpp $(WHISPER_OBJ)
quantize: examples/quantize/quantize.cpp $(WHISPER_OBJ) $(SRC_COMMON)
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o quantize $(LDFLAGS)
server: examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/server/server.cpp $(SRC_COMMON) $(WHISPER_OBJ) -o server $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ)
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
@ -418,9 +421,9 @@ samples:
.PHONY: medium
.PHONY: large-v1
.PHONY: large-v2
.PHONY: large
.PHONY: large-v3
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large: main
tiny.en tiny base.en base small.en small medium.en medium large-v1 large-v2 large-v3: main
bash ./models/download-ggml-model.sh $@
@echo ""
@echo "==============================================="

Package.swift

@ -2,41 +2,26 @@
import PackageDescription
#if arch(arm) || arch(arm64)
let platforms: [SupportedPlatform]? = [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
]
let exclude: [String] = []
let resources: [Resource] = [
.process("ggml-metal.metal")
]
let additionalSources: [String] = ["ggml-metal.m"]
let additionalSettings: [CSetting] = [
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
]
#else
let platforms: [SupportedPlatform]? = nil
let exclude: [String] = ["ggml-metal.metal"]
let resources: [Resource] = []
let additionalSources: [String] = []
let additionalSettings: [CSetting] = []
#endif
let package = Package(
name: "whisper",
platforms: platforms,
platforms: [
.macOS(.v12),
.iOS(.v14),
.watchOS(.v4),
.tvOS(.v14)
],
products: [
.library(name: "whisper", targets: ["whisper"]),
],
dependencies: [
.package(url: "https://github.com/ggerganov/ggml.git", .branch("master"))
],
targets: [
.target(
name: "whisper",
dependencies: ["ggml"],
path: ".",
exclude: exclude + [
exclude: [
"bindings",
"cmake",
"coreml",
@ -51,23 +36,20 @@ let package = Package(
"Makefile"
],
sources: [
"ggml.c",
"whisper.cpp",
"ggml-alloc.c",
"ggml-backend.c",
"ggml-quants.c"
] + additionalSources,
resources: resources,
],
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE")
.define("GGML_USE_ACCELERATE"),
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
// NOTE: NEW_LAPACK will require iOS version 16.4+
// We should consider adding this in the future when we drop support for iOS 14
// (ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
] + additionalSettings,
],
linkerSettings: [
.linkedFramework("Accelerate")
]

README.md

@ -6,7 +6,7 @@
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Beta: [v1.4.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.3) / Stable: [v1.2.1](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.2.1) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Stable: [v1.5.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.5.3) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -33,6 +33,7 @@ Supported platforms:
- [x] [WebAssembly](examples/whisper.wasm)
- [x] Windows ([MSVC](https://github.com/ggerganov/whisper.cpp/blob/master/.github/workflows/build.yml#L117-L144) and [MinGW](https://github.com/ggerganov/whisper.cpp/issues/168))
- [x] [Raspberry Pi](https://github.com/ggerganov/whisper.cpp/discussions/166)
- [x] [docker](https://github.com/ggerganov/whisper.cpp/pkgs/container/whisper.cpp)
The entire high-level implementation of the model is contained in [whisper.h](whisper.h) and [whisper.cpp](whisper.cpp).
The rest of the code is part of the [ggml](https://github.com/ggerganov/ggml) machine learning library.
@ -110,8 +111,8 @@ options:
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
@ -128,6 +129,7 @@ options:
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
@ -139,7 +141,8 @@ options:
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of token
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
bash ./models/download-ggml-model.sh base.en
@ -231,18 +234,18 @@ make medium.en
make medium
make large-v1
make large-v2
make large
make large-v3
```
## Memory usage
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~125 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| base | 142 MB | ~210 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| small | 466 MB | ~600 MB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| medium | 1.5 GB | ~1.7 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| large | 2.9 GB | ~3.3 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
| Model | Disk | Mem |
| --- | --- | --- |
| tiny | 75 MiB | ~273 MB |
| base | 142 MiB | ~388 MB |
| small | 466 MiB | ~852 MB |
| medium | 1.5 GiB | ~2.1 GB |
| large | 2.9 GiB | ~3.9 GB |
## Quantization
@ -446,6 +449,36 @@ make clean
WHISPER_OPENBLAS=1 make -j
```
## Docker
### Prerequisites
* Docker must be installed and running on your system.
* Create a folder to store big models & intermediate files (e.g. /whisper/models)
### Images
We have two Docker images available for this project:
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
### Usage
```shell
# download model and persist it in a local folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./models/download-ggml-model.sh base /models"
# transcribe an audio file
docker run -it --rm \
-v path/to/models:/models \
-v path/to/audios:/audios \
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
# transcribe an audio file in samples folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
```
## Limitations
- Inference only
@ -768,6 +801,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
@ -777,6 +811,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [server](examples/server) | | HTTP transcription server with OAI-like API |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)

bindings/go/Makefile

@ -1,9 +1,26 @@
ifndef UNAME_S
UNAME_S := $(shell uname -s)
endif
ifndef UNAME_P
UNAME_P := $(shell uname -p)
endif
ifndef UNAME_M
UNAME_M := $(shell uname -m)
endif
GGML_METAL_PATH_RESOURCES := $(abspath ../..)
BUILD_DIR := build
MODELS_DIR := models
EXAMPLES_DIR := $(wildcard examples/*)
INCLUDE_PATH := $(abspath ../..)
LIBRARY_PATH := $(abspath ../..)
ifeq ($(UNAME_S),Darwin)
EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit
endif
all: clean whisper examples
whisper: mkdir
@ -11,8 +28,13 @@ whisper: mkdir
@${MAKE} -C ../.. libwhisper.a
test: model-small whisper modtidy
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go test -ldflags "-extldflags '$(EXT_LDFLAGS)'" -v ./pkg/whisper/...
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v .
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go test -v ./pkg/whisper/...
endif
examples: $(EXAMPLES_DIR)
@ -21,7 +43,11 @@ model-small: mkdir examples/go-model-download
$(EXAMPLES_DIR): mkdir whisper modtidy
@echo Build example $(notdir $@)
ifeq ($(UNAME_S),Darwin)
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} GGML_METAL_PATH_RESOURCES=${GGML_METAL_PATH_RESOURCES} go build ${BUILD_FLAGS} -ldflags "-extldflags '$(EXT_LDFLAGS)'" -o ${BUILD_DIR}/$(notdir $@) ./$@
else
@C_INCLUDE_PATH=${INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} go build ${BUILD_FLAGS} -o ${BUILD_DIR}/$(notdir $@) ./$@
endif
mkdir:
@echo Mkdir ${BUILD_DIR}

bindings/go/examples/go-model-download/main.go

@ -24,7 +24,7 @@ const (
var (
// The models which will be downloaded, if no model is specified as an argument
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large"}
modelNames = []string{"ggml-tiny.en", "ggml-tiny", "ggml-base.en", "ggml-base", "ggml-small.en", "ggml-small", "ggml-medium.en", "ggml-medium", "ggml-large-v1", "ggml-large-v2", "ggml-large-v3"}
)
var (

bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java

@ -45,7 +45,7 @@ class WhisperCppTest {
assertEquals(16384, params.n_max_text_ctx);
assertFalse(params.translate);
assertEquals(0.01f, params.thold_pt);
assertEquals(2, params.beam_search.beam_size);
assertEquals(5, params.beam_search.beam_size);
assertEquals(-1.0f, params.beam_search.patience);
}
@ -58,7 +58,7 @@ class WhisperCppTest {
assertEquals(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY.ordinal(), params.strategy);
assertNotEquals(0, params.n_threads);
assertEquals(16384, params.n_max_text_ctx);
assertEquals(2, params.greedy.best_of);
assertEquals(5, params.greedy.best_of);
}
@Test

bindings/javascript/package.json

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.4.3",
"version": "1.5.3",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

File diff suppressed because one or more lines are too long

ggml-backend-impl.h

@ -70,7 +70,7 @@ extern "C" {
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph without a plan
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
// check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);

ggml-backend.c

@ -156,8 +156,8 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
backend->iface.graph_plan_compute(backend, plan);
}
void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
backend->iface.graph_compute(backend, cgraph);
bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
return backend->iface.graph_compute(backend, cgraph);
}
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {

ggml-backend.h

@ -52,7 +52,7 @@ extern "C" {
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
// tensor copy between different backends

examples/CMakeLists.txt

@ -14,6 +14,10 @@ if (WHISPER_SDL2)
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
endif()
if (WHISPER_CLBLAST)
find_package(CLBlast REQUIRED)
endif()
# common
set(TARGET common)
@ -65,6 +69,7 @@ elseif(CMAKE_JS_VERSION)
else()
add_subdirectory(main)
add_subdirectory(stream)
add_subdirectory(server)
add_subdirectory(command)
add_subdirectory(bench)
add_subdirectory(quantize)
@ -72,3 +77,5 @@ else()
add_subdirectory(talk-llama)
add_subdirectory(lsp)
endif()
add_subdirectory(wchess)

examples/common-ggml.cpp

@ -9,6 +9,11 @@ static const std::map<std::string, enum ggml_ftype> GGML_FTYPE_MAP = {
{"q5_0", GGML_FTYPE_MOSTLY_Q5_0},
{"q5_1", GGML_FTYPE_MOSTLY_Q5_1},
{"q8_0", GGML_FTYPE_MOSTLY_Q8_0},
{"q2_k", GGML_FTYPE_MOSTLY_Q2_K},
{"q3_k", GGML_FTYPE_MOSTLY_Q3_K},
{"q4_k", GGML_FTYPE_MOSTLY_Q4_K},
{"q5_k", GGML_FTYPE_MOSTLY_Q5_K},
{"q6_k", GGML_FTYPE_MOSTLY_Q6_K},
};
void ggml_print_ftypes(FILE * fp) {
@ -48,15 +53,15 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_Q5_0: qtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: qtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: qtype = GGML_TYPE_Q8_0; break;
case GGML_FTYPE_MOSTLY_Q2_K: qtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: qtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: qtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: qtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: qtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_UNKNOWN:
case GGML_FTYPE_ALL_F32:
case GGML_FTYPE_MOSTLY_F16:
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16:
case GGML_FTYPE_MOSTLY_Q2_K:
case GGML_FTYPE_MOSTLY_Q3_K:
case GGML_FTYPE_MOSTLY_Q4_K:
case GGML_FTYPE_MOSTLY_Q5_K:
case GGML_FTYPE_MOSTLY_Q6_K:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -167,24 +172,17 @@ bool ggml_common_quantize_0(
switch ((ggml_type) ttype) {
case GGML_TYPE_Q4_0:
{
cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q4_1:
{
cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_0:
{
cur_size = ggml_quantize_q5_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q5_1:
{
cur_size = ggml_quantize_q5_1(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
} break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
{
cur_size = ggml_quantize_q8_0(data_f32.data(), work.data(), nelements, ne[0], hist_cur.data());
cur_size = ggml_quantize_chunk((ggml_type) ttype, data_f32.data(), work.data(), 0, nelements, hist_cur.data());
} break;
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@ -192,11 +190,6 @@ bool ggml_common_quantize_0(
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_K:
case GGML_TYPE_COUNT:
{

examples/common-sdl.cpp

@ -139,10 +139,13 @@ void audio_async::callback(uint8_t * stream, int len) {
return;
}
const size_t n_samples = len / sizeof(float);
size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
if (n_samples > m_audio.size()) {
n_samples = m_audio.size();
stream += (len - (n_samples * sizeof(float)));
}
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
@ -153,7 +156,7 @@ void audio_async::callback(uint8_t * stream, int len) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();

examples/common-sdl.h

@ -41,7 +41,6 @@ private:
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};

examples/helpers.js

@ -22,6 +22,7 @@ var printTextarea = (function() {
async function clearCache() {
if (confirm('Are you sure you want to clear the cache?\nAll the models will be downloaded again.')) {
indexedDB.deleteDatabase(dbName);
location.reload();
}
}

models/download-ggml-model.sh

@ -48,7 +48,7 @@ if [ -n "$3" ]; then
fi
# Whisper models
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
# list available models
function list_models {

examples/main/README.md

@ -17,28 +17,37 @@ options:
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-bs N, --beam-size N [5 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-su, --speed-up [false ] speed up audio by x2 (reduced accuracy)
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [true ] do not print timestamps
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
```

examples/main/main.cpp

@ -165,8 +165,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);

examples/server/CMakeLists.txt

@ -0,0 +1,12 @@
set(TARGET server)
add_executable(${TARGET} server.cpp httplib.h json.hpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${CMAKE_THREAD_LIBS_INIT})
# Check if the compiler is MinGW
if(MINGW)
# Link the necessary libraries for SSL and Winsock
target_link_libraries(${TARGET} PRIVATE -lcrypt32 -lssl -lcrypto -lws2_32)
endif()

examples/server/README.md

@ -0,0 +1,68 @@
# whisper.cpp http server
Simple HTTP server. WAV files are passed to the inference model via HTTP requests.
https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-afe5e4594b8f
## Usage
```
./server -h
usage: ./bin/server [options]
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [2 ] number of best candidates to keep
-bs N, --beam-size N [-1 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pr, --print-realtime [false ] print output in realtime
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
--host HOST, [127.0.0.1] Hostname/ip-address for the server
--port PORT, [8080 ] Port number for the server
--convert, [false ] Convert audio to WAV, requires ffmpeg on the server
```
> [!WARNING]
> **Do not run the server example with administrative privileges and ensure it's operated in a sandbox environment, especially since it involves risky operations like accepting user file uploads and using ffmpeg for format conversions. Always validate and sanitize inputs to guard against potential security threats.**
## request examples
**/inference**
```
curl 127.0.0.1:8080/inference \
-H "Content-Type: multipart/form-data" \
-F file="@<file-path>" \
-F temperature="0.2" \
-F response-format="json"
```
**/load**
```
curl 127.0.0.1:8080/load \
-H "Content-Type: multipart/form-data" \
-F model="<path-to-model-file>"
```

examples/server/httplib.h

File diff suppressed because it is too large

examples/server/json.hpp

File diff suppressed because it is too large

examples/server/server.cpp

@ -0,0 +1,811 @@
#include "common.h"
#include "whisper.h"
#include "httplib.h"
#include "json.hpp"
#include <cmath>
#include <fstream>
#include <cstdio>
#include <string>
#include <thread>
#include <vector>
#include <cstring>
#include <sstream>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
using namespace httplib;
using json = nlohmann::json;
namespace {
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
const std::vector<std::string> k_colors = {
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
};
// output formats
const std::string json_format = "json";
const std::string text_format = "text";
const std::string srt_format = "srt";
const std::string vjson_format = "verbose_json";
const std::string vtt_format = "vtt";
struct server_params
{
std::string hostname = "127.0.0.1";
std::string public_path = "examples/server/public";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
bool ffmpeg_converter = false;
};
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
int32_t beam_size = -1;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
float logprob_thold = -1.00f;
float userdef_temp = 0.20f;
bool speed_up = false;
bool debug_mode = false;
bool translate = false;
bool detect_language = false;
bool diarize = false;
bool tinydiarize = false;
bool split_on_word = false;
bool no_fallback = false;
bool print_special = false;
bool print_colors = false;
bool print_realtime = false;
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
std::string language = "en";
std::string prompt = "";
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
std::string model = "models/ggml-base.en.bin";
std::string response_format = json_format;
// [TDRZ] speaker turn string
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
};
// 500 -> 00:05.000
// 6000 -> 01:00.000
std::string to_timestamp(int64_t t, bool comma = false) {
int64_t msec = t * 10;
int64_t hr = msec / (1000 * 60 * 60);
msec = msec - hr * (1000 * 60 * 60);
int64_t min = msec / (1000 * 60);
msec = msec - min * (1000 * 60);
int64_t sec = msec / 1000;
msec = msec - sec * 1000;
char buf[32];
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
return std::string(buf);
}
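// note: whisper timestamps are expressed in units of 10 ms, so msec = t*10 above
// and (t*WHISPER_SAMPLE_RATE)/100 below converts a timestamp to a sample index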
int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
bool is_file_exist(const char *fileName)
{
std::ifstream infile(fileName);
return infile.good();
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params,
const server_params& sparams) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] \n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
// server params
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, "\n");
}
bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
else if ( arg == "--public") { sparams.public_path = argv[++i]; }
else if ( arg == "--convert") { sparams.ffmpeg_converter = true; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
}
return true;
}
struct whisper_print_user_data {
const whisper_params * params;
const std::vector<std::vector<float>> * pcmf32s;
int progress_prev;
};
void check_ffmpeg_availibility() {
int result = system("ffmpeg -version");
if (result == 0) {
std::cout << "ffmpeg is available." << std::endl;
} else {
// ffmpeg is not available
std::cout << "ffmpeg is not found. Please ensure that ffmpeg is installed ";
std::cout << "and that its executable is included in your system's PATH. ";
exit(0);
}
}
bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) {
std::ostringstream cmd_stream;
std::string converted_filename_temp = temp_filename + "_temp.wav";
cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1";
std::string cmd = cmd_stream.str();
int status = std::system(cmd.c_str());
if (status != 0) {
error_resp = "{\"error\":\"FFmpeg conversion failed.\"}";
return false;
}
// Remove the original file
if (remove(temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to remove the original file.\"}";
return false;
}
// Rename the temporary file to match the original filename
if (rename(converted_filename_temp.c_str(), temp_filename.c_str()) != 0) {
error_resp = "{\"error\":\"Failed to rename the temporary file.\"}";
return false;
}
return true;
}
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
std::string speaker = "";
const int64_t n_samples = pcmf32s[0].size();
const int64_t is0 = timestamp_to_sample(t0, n_samples);
const int64_t is1 = timestamp_to_sample(t1, n_samples);
double energy0 = 0.0f;
double energy1 = 0.0f;
for (int64_t j = is0; j < is1; j++) {
energy0 += fabs(pcmf32s[0][j]);
energy1 += fabs(pcmf32s[1][j]);
}
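// attribute the segment to the channel with at least 10% more energy;
// if the two channels are closer than that, the speaker is ambiguous ("?")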
if (energy0 > 1.1*energy1) {
speaker = "0";
} else if (energy1 > 1.1*energy0) {
speaker = "1";
} else {
speaker = "?";
}
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
if (!id_only) {
speaker.insert(0, "(speaker ");
speaker.append(")");
}
return speaker;
}
void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
if (progress >= *progress_prev + progress_step) {
*progress_prev += progress_step;
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress);
}
}
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
const int n_segments = whisper_full_n_segments(ctx);
std::string speaker = "";
int64_t t0 = 0;
int64_t t1 = 0;
// print the last n_new segments
const int s0 = n_segments - n_new;
if (s0 == 0) {
printf("\n");
}
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
if (params.print_colors) {
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s%s", speaker.c_str(), text);
}
if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}
}
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
fflush(stdout);
}
}
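// concatenate all decoded segments into one string, one segment per line,
// optionally prefixed with the estimated speaker when diarization is enabled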
std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
std::stringstream result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
result << speaker << text << "\n";
}
return result.str();
}
void get_req_parameters(const Request & req, whisper_params & params)
{
// user model configuration
if (req.has_file("offset-t"))
{
params.offset_t_ms = std::stoi(req.get_file_value("offset-t").content);
}
if (req.has_file("offset-n"))
{
params.offset_n = std::stoi(req.get_file_value("offset-n").content);
}
if (req.has_file("duration"))
{
params.duration_ms = std::stoi(req.get_file_value("duration").content);
}
if (req.has_file("max-context"))
{
params.max_context = std::stoi(req.get_file_value("max-context").content);
}
if (req.has_file("prompt"))
{
params.prompt = req.get_file_value("prompt").content;
}
if (req.has_file("response-format"))
{
params.response_format = req.get_file_value("response-format").content;
}
if (req.has_file("temperature"))
{
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
}
}
} // namespace
int main(int argc, char ** argv) {
whisper_params params;
server_params sparams;
std::mutex whisper_mutex;
if (whisper_params_parse(argc, argv, params, sparams) == false) {
whisper_print_usage(argc, argv, params, sparams);
return 1;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
if (sparams.ffmpeg_converter) {
check_ffmpeg_availibility();
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params(); // start from the defaults so every field is initialized
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
Server svr;
svr.set_default_headers({{"Server", "whisper.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
std::string const default_content = "<html>hello</html>";
// this is only called if no index.html is found in the public --path
svr.Get("/", [&default_content](const Request &, Response &res){
res.set_content(default_content, "text/html");
return false;
});
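// POST /inference: accept an audio file as multipart form data,
// transcribe it with the loaded model and return the result in the requested format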
svr.Post("/inference", [&](const Request &req, Response &res){
// acquire whisper model mutex lock
whisper_mutex.lock();
// first check user requested fields of the request
if (!req.has_file("file"))
{
fprintf(stderr, "error: no 'file' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
auto audio_file = req.get_file_value("file");
// check non-required fields
get_req_parameters(req, params);
std::string filename{audio_file.filename};
printf("Received request: %s\n", filename.c_str());
// audio arrays
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// write to temporary file
const std::string temp_filename = "whisper_server_temp_file.wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
temp_file.close();
// if file is not wav, convert to wav
if (sparams.ffmpeg_converter) {
std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}";
const bool is_converted = convert_to_wav(temp_filename, error_resp);
if (!is_converted) {
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
// read wav content into pcmf32
if (!::read_wav(temp_filename, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str());
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
res.set_content(error_resp, "application/json");
std::remove(temp_filename.c_str());
whisper_mutex.unlock();
return;
}
// remove temp file
std::remove(temp_filename.c_str());
printf("Successfully loaded %s\n", filename.c_str());
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, filename.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// run the inference
{
printf("Running whisper.cpp inference on %s\n", filename.c_str());
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
wparams.initial_prompt = params.prompt.c_str();
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.temperature_inc = params.userdef_temp;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
// this callback is called on each new segment
if (params.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
if (wparams.print_progress) {
wparams.progress_callback = whisper_print_progress_callback;
wparams.progress_callback_user_data = &user_data;
}
// examples for abort mechanism
// in examples below, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
// the callback is called before every computation - if it returns true, the computation is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.abort_callback = [](void * user_data) {
bool is_aborted = *(bool*)user_data;
return is_aborted;
};
wparams.abort_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
}
// return results to user
if (params.response_format == text_format)
{
std::string results = output_str(ctx, params, pcmf32s);
res.set_content(results.c_str(), "text/html");
}
else if (params.response_format == srt_format)
{
std::stringstream ss;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
}
ss << i + 1 + params.offset_n << "\n";
ss << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "application/x-subrip");
} else if (params.response_format == vtt_format) {
std::stringstream ss;
ss << "WEBVTT\n\n";
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
std::string speaker = "";
if (params.diarize && pcmf32s.size() == 2)
{
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
speaker.insert(0, "<v Speaker");
speaker.append(">");
}
ss << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
ss << speaker << text << "\n\n";
}
res.set_content(ss.str(), "text/vtt");
}
// TODO add more output formats
else
{
std::string results = output_str(ctx, params, pcmf32s);
json jres = json{
{"text", results}
};
res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
"application/json");
}
// release whisper model mutex lock
whisper_mutex.unlock();
});
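// POST /load: swap the loaded whisper model at runtime;
// the 'model' field must be a path that exists on the server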
svr.Post("/load", [&](const Request &req, Response &res){
whisper_mutex.lock();
if (!req.has_file("model"))
{
fprintf(stderr, "error: no 'model' field in the request\n");
const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
std::string model = req.get_file_value("model").content;
if (!is_file_exist(model.c_str()))
{
fprintf(stderr, "error: 'model': %s not found!\n", model.c_str());
const std::string error_resp = "{\"error\":\"model not found!\"}";
res.set_content(error_resp, "application/json");
whisper_mutex.unlock();
return;
}
// clean up
whisper_free(ctx);
// whisper init
ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
// TODO perhaps load prior model here instead of exit
if (ctx == nullptr) {
fprintf(stderr, "error: model init failed, no model loaded must exit\n");
exit(1);
}
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);
const std::string success = "Load was successful!";
res.set_content(success, "application/text");
whisper_mutex.unlock();
});
svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try {
std::rethrow_exception(std::move(ep));
} catch (std::exception &e) {
snprintf(buf, sizeof(buf), fmt, e.what());
} catch (...) {
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain");
res.status = 500;
});
svr.set_error_handler([](const Request &, Response &res) {
if (res.status == 400) {
res.set_content("Invalid request", "text/plain");
} else if (res.status != 500) {
res.set_content("File Not Found", "text/plain");
res.status = 404;
}
});
// set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port);
return 1;
}
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
// to make it ctrl+clickable:
printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
if (!svr.listen_after_bind())
{
return 1;
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@@ -1,25 +1,18 @@
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
#add_executable(${TARGET} talk-llama.cpp llama.cpp)
#target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
#target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
add_executable(${TARGET} talk-llama.cpp llama.cpp)
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET}
talk-llama.cpp
llama.cpp
../common.cpp
../common-sdl.cpp
../../ggml.c
../../ggml-alloc.c
../../ggml-backend.c
../../ggml-quants.c
../../whisper.cpp)
if (WHISPER_CLBLAST)
set(CLBLAST_LIBNAME clblast)
endif ()
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CLBLAST_LIBNAME} ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
if(WIN32)
# It requires Windows 8.1 or later for PrefetchVirtualMemory
target_compile_definitions(${TARGET} PRIVATE -D_WIN32_WINNT=0x0602)
endif()
include(DefaultTargetOptions)
endif ()

File diff suppressed because it is too large

View File

@@ -39,10 +39,11 @@
#define LLAMA_MAX_RNG_STATE (64*1024)
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 2
#define LLAMA_SESSION_VERSION 3
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@@ -126,7 +127,7 @@ extern "C" {
bool sorted;
} llama_token_data_array;
typedef void (*llama_progress_callback)(float progress, void *ctx);
typedef bool (*llama_progress_callback)(float progress, void *ctx);
// Input data for llama_decode
// A llama_batch object can contain input about one or many sequences
@@ -158,16 +159,38 @@ extern "C" {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_INT,
LLAMA_KV_OVERRIDE_FLOAT,
LLAMA_KV_OVERRIDE_BOOL,
};
struct llama_model_kv_override {
char key[128];
enum llama_model_kv_override_type tag;
union {
int64_t int_value;
double float_value;
bool bool_value;
};
};
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
// called with a progress value between 0 and 1, pass NULL to disable
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted.
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
// override key-value pairs of the model meta data
const struct llama_model_kv_override * kv_overrides;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
@@ -185,17 +208,20 @@ extern "C" {
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
float yarn_attn_factor; // YaRN magnitude scaling factor
float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
enum ggml_type type_k; // data type for K cache
enum ggml_type type_v; // data type for V cache
// Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool f16_kv; // use fp16 for KV cache, fp32 otherwise
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
};
// model quantization parameters
@@ -290,7 +316,9 @@ extern "C" {
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
// TODO: become more consistent with returned int types across the API
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
@@ -301,6 +329,23 @@ extern "C" {
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
// Get the number of metadata key/value pairs
LLAMA_API int llama_model_meta_count(const struct llama_model * model);
// Get metadata key name by index
LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get metadata value as a string by index
LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
// Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
@@ -344,9 +389,60 @@ extern "C" {
// KV cache
//
// Returns the number of tokens in the KV cache
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
// Information associated with an individual cell in the KV cache view.
struct llama_kv_cache_view_cell {
// The position for this cell. Takes KV cache shifts into account.
// May be negative if the cell is not populated.
llama_pos pos;
};
// An updateable view of the KV cache.
struct llama_kv_cache_view {
// Number of KV cache cells. This will be the same as the context size.
int32_t n_cells;
// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
int32_t n_max_seq;
// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids then you'll have 3 tokens.
int32_t token_count;
// Number of populated cache cells.
int32_t used_cells;
// Maximum contiguous empty slots in the cache.
int32_t max_contiguous;
// Index to the start of the max_contiguous slot range. Can be negative
// when cache is full.
int32_t max_contiguous_idx;
// Information for an individual cell.
struct llama_kv_cache_view_cell * cells;
// The sequences for each cell. There will be n_max_seq items per cell.
llama_seq_id * cells_sequences;
};
// Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
// Free a KV cache view. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache
LLAMA_API void llama_kv_cache_clear(
@@ -517,6 +613,12 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_bos_token(const struct llama_model * model);
// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int llama_add_eos_token(const struct llama_model * model);
// codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle

View File

@@ -53,7 +53,7 @@ struct whisper_params {
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
int32_t n_gpu_layers = 0;
int32_t n_gpu_layers = 999;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
@@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7s] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
@@ -282,7 +282,6 @@ int main(int argc, char ** argv) {
// tune these to your liking
lcparams.n_ctx = 2048;
lcparams.seed = 1;
lcparams.f16_kv = true;
lcparams.n_threads = params.n_threads;
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
@@ -686,8 +685,8 @@ int main(int argc, char ** argv) {
}
}
text_to_speak = ::replace(text_to_speak, "\"", "");
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
text_to_speak = ::replace(text_to_speak, "'", "'\"'\"'");
int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + text_to_speak + "'").c_str());
if (ret != 0) {
fprintf(stderr, "%s: failed to speak\n", __func__);
}

View File

@@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
@@ -524,8 +524,7 @@ bool gpt2_eval(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]

View File

@@ -155,33 +155,33 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
@@ -525,8 +525,7 @@ bool gpt2_eval(
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
);
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]

View File

@@ -21,7 +21,7 @@ help()
echo "Usage: ./twitch.sh -s [step] -m [model] -t [threads] [url]"
echo "options:"
echo "-s Step in seconds (default is $step)."
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large' (default is '$model')."
echo "-m Choose model, options are: 'tiny.en' 'tiny' 'base.en' 'base' 'small.en' 'small' 'medium.en' 'medium' 'large-v1' 'large-v2' 'large-v3' (default is '$model')."
echo "-t Number of threads to use."
echo "-h Print this help page."
echo

View File

@@ -0,0 +1,9 @@
set(CMAKE_CXX_STANDARD 11)
add_subdirectory(libwchess)
if (EMSCRIPTEN)
add_subdirectory(wchess.wasm)
else()
add_subdirectory(wchess.cmd)
endif()

examples/wchess/README.md (new file, 45 lines)

@@ -0,0 +1,45 @@
# wchess
Voice-controlled chess using Whisper
Online demo: https://whisper.ggerganov.com/wchess/
https://github.com/ggerganov/whisper.cpp/assets/1991296/c2b2f03c-9684-49f3-8106-357d2d4e67fa
## Command-line tool
```bash
mkdir build && cd build
cmake -DWHISPER_SDL2=1 ..
make -j
./bin/wchess -m ../models/ggml-base.en.bin
Move: start
a b c d e f g h
r n b q k b n r 8
p p p p p p p p 7
. * . * . * . * 6
* . * . * . * . 5
. * . * . * . * 4
* . * . * . * . 3
P P P P P P P P 2
R N B Q K B N R 1
White's turn
[(l)isten/(p)ause/(q)uit]:
```
## TODO
- Fix bugs in the chess moves logic
- Improve web-browser audio capture - sometimes it does not record the voice properly
- Add support for more languages by making the generated grammar string multilingual
- Explore ways to improve the dynamic grammar to be narrower
PRs welcome!
## Thanks
- [chessboardjs](https://chessboardjs.com) for the neat chessboard JS library used in this demo

View File

@@ -0,0 +1,19 @@
add_library(wchess-core STATIC
WChess.cpp
WChess.h
Chessboard.cpp
Chessboard.h
)
target_link_libraries(wchess-core
PUBLIC
whisper
common
)
target_include_directories(wchess-core
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
)
# add_executable(test-chessboard test-chessboard.cpp Chessboard.cpp)

View File

@@ -0,0 +1,803 @@
#include "Chessboard.h"
#include <array>
#include <vector>
#include <algorithm>
#include <cstring>
#include <set>
#include <list>
#include <chrono>
namespace {
constexpr std::array<const char*, 64> positions = {
"a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1",
"a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2",
"a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3",
"a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4",
"a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5",
"a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6",
"a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7",
"a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8",
};
constexpr char INVALID_POS = positions.size();
constexpr int R = 0; // rank index
constexpr int F = 1; // file index
#define FILE (c[F] - '1')
#define RANK (c[R] - 'a')
constexpr char operator ""_P(const char * c, size_t size) {
return size < 2 || RANK < 0 || RANK > 7 ||
FILE < 0 || FILE > 7 ? INVALID_POS : FILE * 8 + RANK;
}
#undef FILE
#undef RANK
struct sview {
const char * ptr = nullptr;
size_t size = 0;
sview() = default;
sview(const char * p, size_t s) : ptr(p), size(s) {}
sview(const std::string& s) : ptr(s.data()), size(s.size()) {}
size_t find(char del, size_t pos) {
while (pos < size && ptr[pos] != del) ++pos;
return pos < size ? pos : std::string::npos;
}
};
std::vector<sview> split(sview str, char del) {
std::vector<sview> res;
size_t cur = 0;
size_t last = 0;
while (cur != std::string::npos) {
if (str.ptr[last] == ' ') {
++last;
continue;
}
cur = str.find(del, last);
size_t len = cur == std::string::npos ? str.size - last : cur - last;
res.emplace_back(str.ptr + last, len);
last = cur + 1;
}
return res;
}
char strToPos(sview str) {
return operator ""_P(str.ptr, str.size);
}
constexpr std::array<const char*, 6> pieceNames = {
"pawn", "knight", "bishop", "rook", "queen", "king",
};
static constexpr std::array<char, 6> blackShort = {
'p', 'n', 'b', 'r', 'q', 'k',
};
static constexpr std::array<char, 6> whiteShort = {
'P', 'N', 'B', 'R', 'Q', 'K',
};
char strToType(sview str) {
auto it = std::find_if(pieceNames.begin(), pieceNames.end(), [str] (const char* name) { return strncmp(name, str.ptr, str.size) == 0; });
return it != pieceNames.end() ? it - pieceNames.begin() : pieceNames.size();
}
// directions
using Direction = std::array<char, 2>;
constexpr Direction N = {(char) 0, (char) 1};
constexpr Direction NNE = {(char) 1, (char) 2};
constexpr Direction NE = {(char) 1, (char) 1};
constexpr Direction ENE = {(char) 2, (char) 1};
constexpr Direction E = {(char) 1, (char) 0};
constexpr Direction ESE = {(char) 2, (char) -1};
constexpr Direction SE = {(char) 1, (char) -1};
constexpr Direction SSE = {(char) 1, (char) -2};
constexpr Direction S = {(char) 0, (char) -1};
constexpr Direction SSW = {(char) -1, (char) -2};
constexpr Direction SW = {(char) -1, (char) -1};
constexpr Direction WSW = {(char) -2, (char) -1};
constexpr Direction W = {(char) -1, (char) 0};
constexpr Direction WNW = {(char) -2, (char) 1};
constexpr Direction NW = {(char) -1, (char) 1};
constexpr Direction NNW = {(char) -1, (char) 2};
char makeStep(char pos, const Direction& d) {
char next[2] = { char(positions[pos][R] + d[R]) , char(positions[pos][F] + d[F]) };
return strToPos(sview{next, sizeof(next)});
}
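// walk from pos in direction d for at most count steps, stopping early at the
// board edge or when the modifier m returns true (e.g. a piece is found);
// returns the last square reached, or INVALID_POS when walking off the board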
template<class Modifier>
char traverse(char pos, const Direction& d, const Modifier& m, int count = 8) {
while (--count >= 0) {
pos = makeStep(pos, d);
if (pos == INVALID_POS || m(pos)) break;
}
return pos;
}
Direction normalize(const Direction& distance) {
//return {char((distance[R] > 0) - (distance[R] < 0)), char((distance[F] > 0) - (distance[F] < 0))};
const int drp = distance[R] > 0 ? 1 : 0;
const int drn = distance[R] < 0 ? 1 : 0;
const int dfp = distance[F] > 0 ? 1 : 0;
const int dfn = distance[F] < 0 ? 1 : 0;
return {char(drp - drn), char(dfp - dfn)};
}
struct Pin {
Direction d;
Piece* pinner;
Piece* pinned;
};
using Pins = std::list<Pin>;
using Board = std::array<Piece*, 64>;
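// keep only the candidate directions that a pinned piece may still move along;
// a {0,0} pin means the piece is not pinned, so every direction is kept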
std::vector<Direction> filter(const Direction& pin, std::initializer_list<Direction> directions) {
if (pin[R] == 0 && pin[F] == 0) return directions;
std::vector<Direction> result;
for (auto& d : directions) {
if ((d[R] == pin[R] || d[R] == -pin[R]) && (d[F] == pin[F] || d[F] == -pin[F])) result.push_back(d);
}
return result;
}
}
class Piece {
public:
enum Types : char {
Pawn,
Knight,
Bishop,
Rook,
Queen,
King,
//
NUM_PIECES
};
enum Colors : char {
White,
Black,
};
const char* name() const;
char initial() const;
Types type() const { return m_type; }
Colors color() const { return m_color; }
char pos() const { return m_pos; }
void setPos(char pos) {
m_pos = pos;
invalidate();
}
const char* coord() const;
const std::set<char>& allowed() const { return m_allowed; }
bool canReach(char pos) const;
virtual bool movePattern(char pos) const = 0;
void take();
virtual void reinit(const State& state) = 0;
void invalidate();
protected:
Piece(Types type, Colors color, char pos, std::set<char> allowed)
: m_type(type), m_color(color), m_pos(pos), m_allowed(std::move(allowed)) {}
Piece(const Piece&) = delete;
~Piece() = default;
const Types m_type;
const Colors m_color;
char m_pos;
std::set<char> m_allowed;
bool m_update = false;
};
struct Pawn : public Piece {
Pawn(Colors color, char pos, std::set<char> next) : Piece(Types::Pawn, color, pos, std::move(next)) {}
bool is_first_move() const {
return m_color ? coord()[F] == '7' : coord()[F] == '2';
}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
Direction distance = {char(next[R] - cur[R]), char(next[F] - cur[F])};
char forward = m_color ? -1 : 1;
return (forward == distance[F] && distance[R] * distance[R] <= 1)
|| (is_first_move() && 2 * forward == distance[F] && distance[R] == 0);
}
virtual void reinit(const State& state) override;
};
struct Knight : public Piece {
Knight(Colors color, char pos, std::set<char> next) : Piece(Types::Knight, color, pos, std::move(next)) {}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
return diff[R]*diff[R] + diff[F]*diff[F] == 5;
}
virtual void reinit(const State& state) override;
};
struct Bishop : public Piece {
Bishop(Colors color, char pos) : Piece(Types::Bishop, color, pos, {}) {}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
return cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
}
virtual void reinit(const State& state) override;
};
struct Rook : public Piece {
Rook(Colors color, char pos) : Piece(Types::Rook, color, pos, {}) {}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
return cur[R] == next[R] || cur[F] == next[F];
}
virtual void reinit(const State& state) override;
};
struct Queen : public Piece {
Queen(Colors color, char pos) : Piece(Types::Queen, color, pos, {}) {}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
return cur[R] == next[R] || cur[F] == next[F] || cur[R] - cur[F] == next[R] - next[F] || cur[R] + cur[F] == next[R] + next[F];
}
virtual void reinit(const State& state) override;
};
struct King : public Piece {
King(Colors color, char pos) : Piece(Types::King, color, pos, {}) {}
virtual bool movePattern(char pos) const override {
if (m_pos == INVALID_POS) return false;
auto cur = coord();
auto next = positions[pos];
Direction diff = {char(next[R] - cur[R]), char(next[F] - cur[F])};
return diff[R]*diff[R] + diff[F]*diff[F] <= 2;
}
virtual void reinit(const State& state) override;
};
struct PieceSet {
Piece* begin() { return &p1; }
Piece* end() { return &r2 + 1; }
const Piece* begin() const { return &p1; }
const Piece* end() const { return &r2 + 1; }
Piece& operator[](int i) { return *(begin() + i); }
const Piece& operator[](int i) const { return *(begin() + i); }
Pawn p1;
Pawn p2;
Pawn p3;
Pawn p4;
Pawn p5;
Pawn p6;
Pawn p7;
Pawn p8;
Rook r1;
Knight n1;
Bishop b1;
Queen q;
King k;
Bishop b2;
Knight n2;
Rook r2;
};
struct State {
State();
PieceSet blacks;
PieceSet whites;
Board board;
Pins blackPins;
Pins whitePins;
};
Direction findPin(const Piece& piece, const State& state) {
auto& pins = piece.color() ? state.blackPins : state.whitePins;
auto it = std::find_if(pins.begin(), pins.end(), [&] (const Pin& pin) { return pin.pinned == &piece; });
if (it != pins.end()) return it->d;
return {0, 0};
}
struct Find {
Find(const Board& board) : m_board(board) {}
bool operator() (char pos) const { return m_board[pos]; }
const Board& m_board;
};
struct Add {
Add(const Board& board, std::set<char>& moves, Piece::Colors color) : m_board(board), m_moves(moves), m_color(color) {}
bool operator() (char pos) const {
if (!m_board[pos] || m_board[pos]->color() != m_color) m_moves.insert(pos);
return m_board[pos];
}
const Board& m_board;
std::set<char>& m_moves;
Piece::Colors m_color;
};
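// pawns: capture one square diagonally forward onto an enemy piece, and
// advance straight ahead (two squares while still on the starting rank)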
void Pawn::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto pin = findPin(*this, state);
auto & left = m_color ? SW : NW;
auto & right = m_color ? SE : NE;
for (auto& direction : filter(pin, { left, right })) {
auto pos = makeStep(m_pos, direction);
if (pos != INVALID_POS && state.board[pos] && state.board[pos]->color() != m_color) m_allowed.insert(pos);
}
auto & forward = m_color ? S : N;
if (!filter(pin, {forward}).empty()) {
traverse(m_pos, forward, [&] (char pos) {
if (!state.board[pos]) m_allowed.insert(pos);
return state.board[pos] || !is_first_move();
}, 2);
}
}
void Knight::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto pin = findPin(*this, state);
if (pin[R] != 0 || pin[F] != 0) return;
for (auto& direction : { NNE, ENE, ESE, SSE, SSW, WSW, WNW, NNW }) {
auto pos = makeStep(m_pos, direction);
if (pos != INVALID_POS && (!state.board[pos] || state.board[pos]->color() != m_color)) m_allowed.insert(pos);
}
}
void Bishop::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto pin = findPin(*this, state);
for (auto& direction : filter(pin, { NE, SE, SW, NW })) {
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
}
}
void Rook::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto pin = findPin(*this, state);
for (auto& direction : filter(pin, { N, E, S, W })) {
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
}
}
void Queen::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto pin = findPin(*this, state);
for (auto& direction : filter(pin, { N, NE, E, SE, S, SW, W, NW })) {
traverse(m_pos, direction, Add(state.board, m_allowed, m_color));
}
}
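// the king may only step onto a square that no enemy piece attacks; for sliding
// attackers, walk from the candidate square toward the piece to confirm the
// line of attack is unobstructed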
void King::reinit(const State& state) {
if (m_pos == INVALID_POS) return;
if (!m_update) return;
m_update = false;
m_allowed.clear();
auto& enemyPieces = m_color ? state.whites : state.blacks;
auto& pawnAttackLeft = m_color ? SW : NW;
auto& pawnAttackRight = m_color ? SE : NE;
for (auto& direction : { N, NE, E, SE, S, SW, W, NW }) {
auto pos = makeStep(m_pos, direction);
bool accept = pos != INVALID_POS && !(state.board[pos] && state.board[pos]->color() == m_color);
if (accept) {
for (auto& p : enemyPieces) {
if (!p.movePattern(pos)) continue;
if (p.type() == Piece::Knight || p.type() == Piece::King) {
accept = false;
break;
}
else if (p.type() == Piece::Pawn) {
auto from = positions[pos];
auto to = p.coord();
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
if (d == pawnAttackLeft || d == pawnAttackRight) {
accept = false;
break;
}
}
else {
auto from = positions[pos];
auto to = p.coord();
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
auto reached = traverse(pos, d, Find(state.board));
if (p.pos() == reached) {
accept = false;
break;
}
}
}
}
if (accept) m_allowed.insert(pos);
}
}
const char* Piece::name() const {
static_assert(pieceNames.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
return pieceNames[m_type];
}
char Piece::initial() const {
static_assert(blackShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
static_assert(whiteShort.size() == Piece::NUM_PIECES, "Mismatch between piece names and types");
return m_color ? blackShort[m_type] : whiteShort[m_type];
}
void Piece::invalidate() {
m_update = true;
}
const char* Piece::coord() const {
if (m_pos == INVALID_POS) return "";
return positions[m_pos];
}
bool Piece::canReach(char pos) const {
return movePattern(pos) && m_allowed.count(pos);
}
void Piece::take() {
m_pos = INVALID_POS;
m_allowed = {};
}
State::State()
: blacks {
{Piece::Black, "a7"_P, {"a5"_P, "a6"_P} },
{Piece::Black, "b7"_P, {"b5"_P, "b6"_P} },
{Piece::Black, "c7"_P, {"c5"_P, "c6"_P} },
{Piece::Black, "d7"_P, {"d5"_P, "d6"_P} },
{Piece::Black, "e7"_P, {"e5"_P, "e6"_P} },
{Piece::Black, "f7"_P, {"f5"_P, "f6"_P} },
{Piece::Black, "g7"_P, {"g5"_P, "g6"_P} },
{Piece::Black, "h7"_P, {"h5"_P, "h6"_P} },
{Piece::Black, "a8"_P},
{Piece::Black, "b8"_P, {"a6"_P, "c6"_P} },
{Piece::Black, "c8"_P},
{Piece::Black, "d8"_P},
{Piece::Black, "e8"_P},
{Piece::Black, "f8"_P},
{Piece::Black, "g8"_P, {"f6"_P, "h6"_P} },
{Piece::Black, "h8"_P},
}
, whites {
{Piece::White, "a2"_P, {"a3"_P, "a4"_P} },
{Piece::White, "b2"_P, {"b3"_P, "b4"_P} },
{Piece::White, "c2"_P, {"c3"_P, "c4"_P} },
{Piece::White, "d2"_P, {"d3"_P, "d4"_P} },
{Piece::White, "e2"_P, {"e3"_P, "e4"_P} },
{Piece::White, "f2"_P, {"f3"_P, "f4"_P} },
{Piece::White, "g2"_P, {"g3"_P, "g4"_P} },
{Piece::White, "h2"_P, {"h3"_P, "h4"_P} },
{Piece::White, "a1"_P},
{Piece::White, "b1"_P, {"a3"_P, "c3"_P} },
{Piece::White, "c1"_P},
{Piece::White, "d1"_P},
{Piece::White, "e1"_P},
{Piece::White, "f1"_P},
{Piece::White, "g1"_P, {"f3"_P, "h3"_P} },
{Piece::White, "h1"_P},
}
, board {{
&whites[ 8], &whites[ 9], &whites[10], &whites[11], &whites[12], &whites[13], &whites[14], &whites[15],
&whites[ 0], &whites[ 1], &whites[ 2], &whites[ 3], &whites[ 4], &whites[ 5], &whites[ 6], &whites[ 7],
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
&blacks[ 0], &blacks[ 1], &blacks[ 2], &blacks[ 3], &blacks[ 4], &blacks[ 5], &blacks[ 6], &blacks[ 7],
&blacks[ 8], &blacks[ 9], &blacks[10], &blacks[11], &blacks[12], &blacks[13], &blacks[14], &blacks[15],
}}
{}
Chessboard::Chessboard()
: m_state(new State())
{
setGrammar();
}
Chessboard::~Chessboard() = default;
void Chessboard::setPrompt(const std::string& prompt) {
m_prompt = prompt;
setGrammar();
}
void Chessboard::setGrammar() {
m_grammar.clear();
std::string result;
if (m_prompt.empty()) {
result += "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n";
//result += "move ::= \" \" frompos \" \" \"to \"? topos\n";
}
else {
// result += "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n"
result += "move ::= prompt \" \" frompos \" \" \"to \"? topos\n"
"prompt ::= \" " + m_prompt + "\"\n";
}
std::set<Piece::Types> pieceTypes;
std::set<char> from_pos;
std::set<char> to_pos;
auto& pieces = m_moveCounter % 2 ? m_state->blacks : m_state->whites;
std::set<size_t> flags;
for (auto& p : pieces) {
if (p.allowed().empty()) continue;
bool addPiece = false;
if (!m_inCheck || p.type() == Piece::King) {
to_pos.insert(p.allowed().begin(), p.allowed().end());
addPiece = !p.allowed().empty();
}
else {
for (auto move : p.allowed()) {
if (m_allowedInCheck.count(move)) {
to_pos.insert(move);
addPiece = true;
}
}
}
if (addPiece) {
pieceTypes.insert(p.type());
from_pos.insert(p.pos());
}
}
if (pieceTypes.empty()) return;
result += "piece ::= (";
for (auto& p : pieceTypes) result += " \"" + std::string(pieceNames[p]) + "\" |";
result.pop_back();
result += ")\n\n";
result += "frompos ::= (";
for (auto& p : from_pos) result += " \"" + std::string(positions[p]) + "\" |";
result.pop_back();
result += ")\n";
result += "topos ::= (";
for (auto& p : to_pos) result += " \"" + std::string(positions[p]) + "\" |";
result.pop_back();
result += ")\n";
m_grammar = std::move(result);
}
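// Illustrative sketch (not verbatim output): at the start of a game the
// generated grammar looks roughly like
//   move ::= " " ((piece | frompos) " " "to "?)? topos
//   piece ::= ( "pawn" | "knight")
//   frompos ::= ( "a2" | "b2" | ... | "b1" | "g1")
//   topos ::= ( "a3" | "a4" | ... | "c3" | "f3")
// so decoding can only produce text that spells a currently legal move.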
std::string Chessboard::stringifyBoard() {
std::string result;
result.reserve(16 + 2 * 64 + 16);
for (char rank = 'a'; rank <= 'h'; ++rank) {
result.push_back(rank);
result.push_back(' ');
}
result.back() = '\n';
for (int i = 7; i >= 0; --i) {
for (int j = 0; j < 8; ++j) {
auto p = m_state->board[i * 8 + j];
if (p) result.push_back(p->initial());
else result.push_back((i + j) % 2 ? '.' : '*');
result.push_back(' ');
}
result.push_back('0' + i + 1);
result.push_back('\n');
}
return result;
}
std::string Chessboard::process(const std::string& command) {
const auto t_start = std::chrono::high_resolution_clock::now();
auto color = Piece::Colors(m_moveCounter % 2);
Piece* piece = nullptr;
auto pos_to = INVALID_POS;
if (!parseCommand(command, piece, pos_to)) return "";
auto pos_from = piece->pos();
if (!move(*piece, pos_to)) return "";
flagUpdates(pos_from, pos_to);
detectChecks();
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
for (auto& p : enemyPieces) p.reinit(*m_state); // only enemy moves needed next
std::string result = {positions[pos_from][R], positions[pos_from][F], '-', positions[pos_to][R], positions[pos_to][F]};
++m_moveCounter;
setGrammar();
const auto t_end = std::chrono::high_resolution_clock::now();
auto t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
fprintf(stdout, "%s: Move '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", result.data(), "\033[0m", (int) t_ms);
if (m_grammar.empty()) result.push_back('#');
return result;
}
bool Chessboard::parseCommand(const std::string& command, Piece*& piece, char& pos_to) {
auto color = Piece::Colors(m_moveCounter % 2);
fprintf(stdout, "%s: Command to %s: '%s%.*s%s'\n", __func__, (color ? "Black" : "White"), "\033[1m", int(command.size()), command.data(), "\033[0m");
if (command.empty()) return false;
auto tokens = split(command, ' ');
auto pos_from = INVALID_POS;
auto type = Piece::Types::NUM_PIECES;
if (tokens.size() == 1) {
type = Piece::Types::Pawn;
pos_to = strToPos(tokens.front());
}
else {
pos_from = strToPos(tokens.front());
if (pos_from == INVALID_POS) type = Piece::Types(strToType(tokens.front()));
pos_to = strToPos(tokens.back());
}
if (pos_to == INVALID_POS) return false;
if (pos_from == INVALID_POS) {
if (type == Piece::Types::NUM_PIECES) return false;
auto& pieces = color ? m_state->blacks : m_state->whites;
for (auto& p : pieces) {
if (p.type() == type && p.canReach(pos_to)) {
pos_from = p.pos();
break;
}
}
}
if (pos_from == INVALID_POS) return false;
if (m_state->board[pos_from] == nullptr) return false;
piece = m_state->board[pos_from];
if (piece->color() != color) return false;
return true;
}
void Chessboard::flagUpdates(char pos_from, char pos_to) {
auto color = Piece::Colors(m_moveCounter % 2);
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
auto& ownPieces = color ? m_state->blacks : m_state->whites;
for (auto& p : enemyPieces) {
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
updatePins(p);
p.invalidate();
}
}
for (auto& p : ownPieces) {
if (p.movePattern(pos_to) || p.movePattern(pos_from)) {
updatePins(p);
p.invalidate();
}
}
}
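// Maintain the pin list for a sliding piece relative to the enemy king:
// drop any pin this piece previously held, then re-trace its ray towards
// the king. Reaching the king directly is a check; reaching exactly one
// enemy piece with the king immediately behind it pins that piece.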
void Chessboard::updatePins(Piece& piece) {
if (piece.type() == Piece::Pawn || piece.type() == Piece::Knight || piece.type() == Piece::King) return;
auto& enemyPieces = piece.color() ? m_state->whites : m_state->blacks;
auto& enemyPins = piece.color() ? m_state->whitePins : m_state->blackPins;
auto& king = enemyPieces.k;
auto it = std::find_if(enemyPins.begin(), enemyPins.end(), [&] (const Pin& pin) { return pin.pinner == &piece; });
if (it != enemyPins.end()) {
it->pinned->invalidate();
enemyPins.erase(it);
}
if (piece.movePattern(king.pos())) {
auto to = positions[king.pos()];
auto from = piece.coord();
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
auto reached = traverse(piece.pos(), d, Find(m_state->board));
auto foundPiece = m_state->board[reached];
if (&king == foundPiece) {
// check
king.invalidate();
}
else if (foundPiece && foundPiece->color() != piece.color()) {
reached = traverse(reached, d, Find(m_state->board));
if (&king == m_state->board[reached]) {
enemyPins.push_back({d, &piece, foundPiece});
foundPiece->invalidate();
}
}
}
}
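// Scan for checks against the side to move next. m_allowedInCheck collects
// the squares that can resolve the check (capturing the checker or, for
// sliders, blocking its ray); on double check it is cleared, leaving the
// king's own moves as the only escape.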
void Chessboard::detectChecks() {
auto color = Piece::Colors(m_moveCounter % 2);
auto& enemyPieces = color ? m_state->whites : m_state->blacks;
auto& ownPieces = color ? m_state->blacks : m_state->whites;
auto& king = enemyPieces.k;
auto& pawnAttackLeft = color ? SW : NW;
auto& pawnAttackRight = color ? SE : NE;
for (auto& p : ownPieces) {
if (!p.movePattern(king.pos())) continue;
auto to = positions[king.pos()];
auto from = p.coord();
if (p.type() == Piece::Knight) {
if (!m_inCheck) {
m_allowedInCheck = { p.pos() };
}
else {
m_allowedInCheck.clear();
}
m_inCheck = true;
}
else if (p.type() == Piece::Pawn) {
Direction d {char(to[R] - from[R]), char(to[F] - from[F])};
if (d == pawnAttackLeft || d == pawnAttackRight) {
if (!m_inCheck) {
m_allowedInCheck = { p.pos() };
}
else {
m_allowedInCheck.clear();
}
m_inCheck = true;
}
}
else {
Direction d = normalize({char(to[R] - from[R]), char(to[F] - from[F])});
std::set<char> tmp;
auto pos = traverse(p.pos(), d, Add(m_state->board, tmp, king.color()));
if (pos == king.pos()) {
tmp.insert(p.pos());
if (!m_inCheck) {
m_allowedInCheck = std::move(tmp);
}
else {
m_allowedInCheck.clear();
}
m_inCheck = true;
}
}
}
}
bool Chessboard::move(Piece& piece, char pos_to) {
auto& allowed = piece.allowed();
if (allowed.count(pos_to) == 0 || (m_inCheck && piece.type() != Piece::King && m_allowedInCheck.count(pos_to) == 0)) return false;
if (m_state->board[pos_to] && m_state->board[pos_to]->color() == piece.color()) return false;
if (m_state->board[pos_to]) m_state->board[pos_to]->take();
m_state->board[piece.pos()] = nullptr;
m_state->board[pos_to] = &piece;
piece.setPos(pos_to);
m_inCheck = false;
m_allowedInCheck.clear();
return true;
}

View File

@ -0,0 +1,33 @@
#pragma once
#include <string>
#include <set>
#include <memory>
// just basic validation
// fixme: missing en passant, castling, promotion, etc.
struct State;
class Piece;
class Chessboard {
public:
Chessboard();
~Chessboard();
std::string process(const std::string& command);
std::string stringifyBoard();
const std::string& grammar() { return m_grammar; }
const std::string& prompt() { return m_prompt; }
void setPrompt(const std::string& prompt);
private:
bool parseCommand(const std::string& command, Piece*& piece, char& pos_to);
bool move(Piece& piece, char pos);
void flagUpdates(char pos_from, char pos_to);
void updatePins(Piece& piece);
void detectChecks();
void setGrammar();
std::unique_ptr<State> m_state;
std::set<char> m_allowedInCheck;
bool m_inCheck = false;
int m_moveCounter = 0;
std::string m_grammar;
std::string m_prompt;
};

View File

@ -0,0 +1,193 @@
#include "WChess.h"
#include "Chessboard.h"
#include "grammar-parser.h"
#include "common.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <thread>
WChess::WChess(whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
settings s)
: m_ctx(ctx)
, m_wparams(wparams)
, m_cb(cb)
, m_settings(s)
, m_board(new Chessboard())
{}
WChess::~WChess() = default;
void WChess::set_move(const std::string& moves, float prob) const {
if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
}
void WChess::set_grammar(const std::string& grammar) const {
if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
}
bool WChess::get_audio(std::vector<float>& pcmf32) const {
if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
return false;
}
std::string WChess::stringify_board() const {
return m_board->stringifyBoard();
}
std::string WChess::get_grammar() const {
return m_board->grammar();
}
void WChess::run() {
bool have_prompt = true;
bool ask_prompt = !have_prompt;
float logprob_min = 0.0f;
float logprob_sum = 0.0f;
int n_tokens = 0;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";
int64_t t_ms = 0;
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
while (get_audio(pcmf32_cur)) {
if (!pcmf32_cur.empty()) {
// fprintf(stdout, "%s: Processing ...\n", __func__);
if (!have_prompt) {
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
const float sim = similarity(txt, k_prompt);
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
ask_prompt = true;
} else {
fprintf(stdout, "\n");
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
fprintf(stdout, "\n");
// save the audio for the prompt
pcmf32_prompt = pcmf32_cur;
have_prompt = true;
m_board->setPrompt(k_prompt);
}
} else {
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
if (MIN_SIZE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
// fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());
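// The set of legal moves changes every turn, so the board's grammar is
// re-parsed and plugged into the whisper params before each transcription.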
auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
auto grammar_rules = grammar_parsed.c_rules();
m_wparams.grammar_rules = grammar_rules.data();
m_wparams.n_grammar_rules = grammar_rules.size();
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
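// Rough confidence estimate: map the worst per-token log-probability to a percentage.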
const float p = 100.0f * std::exp(logprob_min);
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
// find the prompt in the text
float best_sim = 0.0f;
size_t best_len = 0;
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
const auto prompt = txt.substr(0, n);
const float sim = similarity(prompt, k_prompt);
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
if (sim > best_sim) {
best_sim = sim;
best_len = n;
}
}
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
std::string command = ::trim(txt.substr(best_len));
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "\n");
if (!command.empty()) {
set_move(m_board->process(command), p);
set_grammar(m_board->grammar());
}
if (m_board->grammar().empty()) {
fprintf(stdout, "%s: No more moves possible\n", __func__);
break;
}
}
}
if (ask_prompt) {
fprintf(stdout, "\n");
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
fprintf(stdout, "\n");
ask_prompt = false;
}
}
}
std::string WChess::transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
logprob_min = 0.0f;
logprob_sum = 0.0f;
n_tokens = 0;
t_ms = 0;
if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
return {};
}
std::string result;
const int n_segments = whisper_full_n_segments(m_ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(m_ctx, i);
result += text;
const int n = whisper_full_n_tokens(m_ctx, i);
for (int j = 0; j < n; ++j) {
const auto token = whisper_full_get_token_data(m_ctx, i, j);
if(token.plog > 0.0f) return {};
logprob_min = std::min(logprob_min, token.plog);
logprob_sum += token.plog;
++n_tokens;
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}

View File

@ -0,0 +1,63 @@
#pragma once
#include "whisper.h"
#include <string>
#include <vector>
#include <memory>
class Chessboard;
class WChess {
public:
using CheckRunningCb = bool (*)();
using GetAudioCb = bool (*)(std::vector<float> &);
using SetMovesCb = void (*)(const std::string &, float);
using SetGrammarCb = void (*)(const std::string &);
using ClearAudioCb = void (*)();
struct callbacks {
GetAudioCb get_audio = nullptr;
SetMovesCb set_move = nullptr;
SetGrammarCb set_grammar = nullptr;
};
struct settings {
int32_t vad_ms = 2000;
int32_t prompt_ms = 5000;
int32_t command_ms = 4000;
float vad_thold = 0.2f;
float freq_thold = 100.0f;
bool print_energy = false;
};
WChess(
whisper_context * ctx,
const whisper_full_params & wparams,
callbacks cb,
settings s
);
~WChess();
void run();
std::string stringify_board() const;
std::string get_grammar() const;
private:
bool get_audio(std::vector<float>& pcmf32) const;
void set_move(const std::string& moves, float prob) const;
void set_grammar(const std::string& grammar) const;
std::string transcribe(
const std::vector<float> & pcmf32,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms);
whisper_context * m_ctx;
whisper_full_params m_wparams;
const callbacks m_cb;
const settings m_settings;
std::unique_ptr<Chessboard> m_board;
};

View File

@ -0,0 +1,117 @@
#include "Chessboard.h"
#define ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
fflush(stderr); \
exit(1); \
} \
} while (0)
int main() {
{
Chessboard chess;
ASSERT(chess.process("pawn to d4") == "d2-d4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("c1 h6") == "c1-h6");
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("bishop to g5") == "h6-g5");
ASSERT(chess.process("bishop to b4") == "f8-b4");
ASSERT(chess.process("c4") == "");
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6");
ASSERT(chess.process("f3") == "");
}
{
Chessboard chess;
ASSERT(chess.process("d4") == "d2-d4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("queen h4") == "d8-h4");
ASSERT(chess.process("queen h5") == "d1-h5");
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("g6") == "g7-g6");
ASSERT(chess.process("knight e2") == "g1-e2");
ASSERT(chess.process("f5") == "f7-f5");
ASSERT(chess.process("knight g3") == "e2-g3");
ASSERT(chess.process("g5") == "");
ASSERT(chess.process("king e7") == "e8-e7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("g5") == "g6-g5");
}
{
Chessboard chess;
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("c5") == "c7-c5");
ASSERT(chess.process("e5") == "e4-e5");
ASSERT(chess.process("c4") == "c5-c4");
ASSERT(chess.process("e6") == "e5-e6");
ASSERT(chess.process("c3") == "c4-c3");
ASSERT(chess.process("e7") == "");
ASSERT(chess.process("f7") == "e6-f7");
ASSERT(chess.process("d2") == "");
ASSERT(chess.process("king to f7") == "e8-f7");
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("d2") == "c3-d2");
ASSERT(chess.process("f5") == "");
ASSERT(chess.process("king to e2") == "e1-e2");
ASSERT(chess.process("king to g6") == "f7-g6");
ASSERT(chess.process("f5") == "f4-f5");
ASSERT(chess.process("e6") == "");
ASSERT(chess.process("king to h5") == "g6-h5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("king to g5") == "h5-g5");
ASSERT(chess.process("h4") == "h2-h4");
ASSERT(chess.process("king to h5") == "");
ASSERT(chess.process("king to g6") == "");
ASSERT(chess.process("king to h6") == "g5-h6");
ASSERT(chess.process("bishop to d2") == "c1-d2");
ASSERT(chess.process("king to g5") == "");
ASSERT(chess.process("g5") == "g7-g5");
}
{
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("queen to h4") == "d8-h4#");
ASSERT(chess.process("knight f3") == "");
ASSERT(chess.grammar().empty());
}
{
Chessboard chess;
ASSERT(chess.process("f4") == "f2-f4");
ASSERT(chess.process("e5") == "e7-e5");
ASSERT(chess.process("g4") == "g2-g4");
ASSERT(chess.process("d5") == "d7-d5");
ASSERT(chess.process("g1 f3") == "g1-f3");
ASSERT(chess.process("queen to h4") == "d8-h4");
ASSERT(!chess.grammar().empty());
}
{
Chessboard chess;
ASSERT(chess.process("knight c3") == "b1-c3");
ASSERT(chess.process("knight c6") == "b8-c6");
ASSERT(chess.process("knight b5") == "c3-b5");
ASSERT(chess.process("knight f6") == "g8-f6");
ASSERT(chess.process("knight d6") == "b5-d6");
ASSERT(chess.process("knight d4") == "");
ASSERT(chess.process("d6") == "c7-d6");
ASSERT(chess.process("e4") == "e2-e4");
ASSERT(chess.process("knight d4") == "c6-d4");
ASSERT(chess.process("d3") == "d2-d3");
ASSERT(chess.process("knight e4") == "f6-e4");
ASSERT(chess.process("king to e2") == "");
ASSERT(chess.process("king to d2") == "");
}
}

View File

@ -0,0 +1,8 @@
if (WHISPER_SDL2)
set(TARGET wchess)
add_executable(${TARGET} wchess.cmd.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE wchess-core common-sdl ${CMAKE_THREAD_LIBS_INIT})
endif ()

View File

@ -0,0 +1,247 @@
// Command line voice assisted chess
//
// Speak chess move commands to the microphone.
// The moves will be translated to chessboard positions.
//
//
#include "WChess.h"
#include "common-sdl.h"
#include <iostream>
#include <memory>
#include <thread>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t prompt_ms = 5000;
int32_t command_ms = 8000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float grammar_penalty = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
std::string language = "en";
std::string model = "models/ggml-base.en.bin";
std::string fname_out;
std::string commands;
std::string prompt;
std::string context;
std::string grammar;
};
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms);
fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); }
else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
std::unique_ptr<WChess> g_wchess;
int g_moveCount = 0;
void set_move(const std::string & move, float) {
if (!move.empty()) {
g_moveCount++;
fprintf(stdout, "Move: %s\n\n", move.c_str());
}
else fprintf(stdout, "Move rejected\n\n");
fprintf(stdout, "%s\n", g_wchess->stringify_board().c_str());
fprintf(stdout, "%s\n", g_moveCount ? "White's turn" : "Black's turn");
}
audio_async g_audio(30*1000);
bool g_listening = false;
std::vector<float> g_pcmf32;
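// Simple keyboard REPL in place of voice activity detection: 'l' starts
// capturing, any other input stops it and hands the buffered audio to WChess,
// and 'q' quits.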
bool read_input() {
std::string input;
while (true) {
fprintf(stdout, "[(l)isten/(p)ause/(q)uit]: ");
std::cin >> input;
fprintf(stdout, "\n");
if (input[0] == 'q') {
fprintf(stdout, "Quitting\n");
return false;
}
if (input[0] == 'l') {
if (!g_listening) {
fprintf(stdout, "Listening\n");
g_listening = true;
g_pcmf32.clear();
g_audio.resume();
g_audio.clear();
}
else fprintf(stdout, "Still listening\n");
return true;
}
else {
if (g_listening) {
g_listening = false;
g_audio.get(0, g_pcmf32);
g_audio.pause();
fprintf(stdout, "Processing\n");
}
else fprintf(stdout, "Not listening\n");
return true;
}
}
return true;
}
bool get_audio(std::vector<float> & pcmf32_cur) {
if (!read_input()) return false;
if (!g_pcmf32.empty()) pcmf32_cur = std::move(g_pcmf32);
else pcmf32_cur.clear();
return true;
}
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (!ctx) {
fprintf(stderr, "%s: whisper_init_from_file_with_params() failed!\n", __func__);
return 1;
}
// init audio
if (!g_audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 1;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = params.context.data();
WChess::callbacks cb;
cb.get_audio = get_audio;
cb.set_move = set_move;
WChess::settings s;
s.vad_ms = 2000;
s.prompt_ms = params.prompt_ms;
s.command_ms = params.command_ms;
s.vad_thold = params.vad_thold;
s.freq_thold = params.freq_thold;
s.print_energy = params.print_energy;
g_wchess.reset(new WChess(ctx, wparams, cb, s));
set_move("start", 0);
g_wchess->run();
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}

View File

@ -0,0 +1,51 @@
set(TARGET wchess.wasm)
add_executable(${TARGET}
wchess.wasm.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
common
wchess-core
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside chess.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/${TARGET}.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/chess.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1024MB \
-s TOTAL_MEMORY=1024MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory
${CMAKE_CURRENT_SOURCE_DIR}/chessboardjs-1.0.0
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/jquery-3.7.1.min.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/
)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_SOURCE_DIR}/examples/helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY)

View File

@ -0,0 +1,54 @@
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
.clearfix-7da63 {
clear: both;
}
.board-b72b1 {
border: 2px solid #404040;
box-sizing: content-box;
}
.square-55d63 {
float: left;
position: relative;
/* disable any native browser highlighting */
-webkit-touch-callout: none;
-webkit-user-select: none;
-khtml-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
.white-1e1d7 {
background-color: #f0d9b5;
color: #b58863;
}
.black-3c85d {
background-color: #b58863;
color: #f0d9b5;
}
.highlight1-32417, .highlight2-9c5d2 {
box-shadow: inset 0 0 3px 3px yellow;
}
.notation-322f9 {
cursor: default;
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
font-size: 14px;
position: absolute;
}
.alpha-d2270 {
bottom: 1px;
right: 3px;
}
.numeric-fc462 {
top: 2px;
left: 2px;
}

View File

@ -0,0 +1,2 @@
/*! chessboard.js v1.0.0 | (c) 2019 Chris Oakman | MIT License chessboardjs.com/license */
.clearfix-7da63{clear:both}.board-b72b1{border:2px solid #404040;box-sizing:content-box}.square-55d63{float:left;position:relative;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}.white-1e1d7{background-color:#f0d9b5;color:#b58863}.black-3c85d{background-color:#b58863;color:#f0d9b5}.highlight1-32417,.highlight2-9c5d2{box-shadow:inset 0 0 3px 3px #ff0}.notation-322f9{cursor:default;font-family:"Helvetica Neue",Helvetica,Arial,sans-serif;font-size:14px;position:absolute}.alpha-d2270{bottom:1px;right:3px}.numeric-fc462{top:2px;left:2px}

Binary files not shown: 12 chess-piece images added (sizes 748 B – 3.7 KiB).

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,32 @@
# chessboard.js Change Log
All notable changes to this project will be documented in this file.
## [1.0.0] - 2019-06-11
- Orientation methods now return current orientation. [Issue #64]
- Drop support for IE8
- Do not check for `window.JSON` (Error #1004)
- Rename `ChessBoard` to `Chessboard` (`ChessBoard` is still supported, however)
- id query selectors are now supported as the first argument to `Chessboard()`
- Remove Error #1002
- Format code according to [StandardJS]
- Bump minimum jQuery version to 1.8.3
- Throttle piece drag functions
## [0.3.0] - 2013-08-10
- Added `appearSpeed` animation config property
- Added `onSnapbackEnd` event
- Added `onMoveEnd` event
## [0.2.0] - 2013-08-05
- Added `onMouseoverSquare` and `onMouseoutSquare` events
- Added `onSnapEnd` event
- Added square code as CSS class on the squares
- Added [chess.js] integration examples
## [0.1.0] - 2013-05-21
- Initial release
[chess.js]:https://github.com/jhlywa/chess.js
[Issue #64]:https://github.com/oakmac/chessboardjs/issues/64
[StandardJS]:https://standardjs.com/

View File

@ -0,0 +1,20 @@
Copyright 2019 Chris Oakman
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,82 @@
# chessboard.js
chessboard.js is a JavaScript chessboard component. It depends on [jQuery].
Please see [chessboardjs.com] for documentation and examples.
## What is chessboard.js?
chessboard.js is a standalone JavaScript chessboard component. It is designed to
be "just a board" and to expose a flexible, powerful API so that it can be used
in many different ways.
Here's a non-exhaustive list of things you can do with chessboard.js:
- Use chessboard.js to show game positions alongside your expert commentary.
- Use chessboard.js to have a tactics website where users have to guess the best
move.
- Integrate chessboard.js and [chess.js] with a PGN database and allow people to
search and playback games (see [Example 5000])
- Build a chess server and have users play their games out using the
chessboard.js board.
chessboard.js is flexible enough to handle any of these situations with relative
ease.
## What can chessboard.js **not** do?
The scope of chessboard.js is limited to "just a board." This is intentional and
makes chessboard.js flexible for handling a multitude of chess-related problems,
but it is also a common source of confusion for new users.
Specifically, chessboard.js does not understand anything about how the game of
chess is played: how a knight moves, whose turn it is, whether White is in check,
etc. This logic is distinct from the logic of the board itself. Fortunately, the
powerful [chess.js] library deals with exactly this problem domain and plays
nicely with chessboard.js's flexible API (see examples 5000, 5001 and 5002).
Here is a list of things that chessboard.js is **not**:
- A chess engine
- A legal move validator
- A PGN parser
chessboard.js is designed to work well with any of those things, but the idea
behind chessboard.js is that the logic that controls the board should be
independent of those other problems.
## Docs and Examples
- Docs - <http://chessboardjs.com/docs>
- Examples - <http://chessboardjs.com/examples>
## Developer Tools
```sh
# create a build in the build/ directory
npm run build
# re-build the website
npm run website
```
## License
[MIT License](LICENSE.md)
[jQuery]:https://jquery.com/
[chessboardjs.com]:http://chessboardjs.com
[chess.js]:https://github.com/jhlywa/chess.js
[Example 5000]:http://chessboardjs.com/examples#5000

View File

@ -0,0 +1,29 @@
{
"author": "Chris Oakman <chris@oakmac.com> (http://chrisoakman.com/)",
"name": "@chrisoakman/chessboardjs",
"description": "JavaScript chessboard widget",
"homepage": "https://chessboardjs.com",
"license": "MIT",
"version": "1.0.0",
"repository": {
"type": "git",
"url": "git://github.com/oakmac/chessboardjs.git"
},
"files": ["dist/"],
"dependencies": {
"jquery": ">=3.4.1"
},
"devDependencies": {
"csso": "3.5.1",
"fs-plus": "3.1.1",
"kidif": "1.1.0",
"mustache": "2.3.0",
"standard": "10.0.2",
"uglify-js": "3.6.0"
},
"scripts": {
"build": "standard lib/chessboard.js && node scripts/build.js",
"standard": "standard --fix lib/*.js website/js/*.js",
"website": "node scripts/website.js"
}
}

View File

@ -0,0 +1,499 @@
<!doctype html>
<html lang="en-us">
<head>
<title>wchess : voice-controlled chess using Whisper + WebAssembly</title>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<meta name="viewport" content="width=device-width, initial-scale=0.7, maximum-scale=1, minimum-scale=0.7, user-scalable=no"/>
<meta name="apple-mobile-web-app-capable" content="yes" />
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
.button {
background-color: #000000;
color: #FFFFFF;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
height: 50px;
-webkit-touch-callout: none; /* Safari */
-webkit-user-select: none; /* Chrome */
-moz-user-select: none; /* Firefox */
-ms-user-select: none; /* Internet Explorer/Edge */
user-select: none;
}
button[disabled]{
background-color: #cccccc;
color: #666666;
padding: 20px;
border-radius: 10px;
-moz-border-radius: 10px;
-webkit-border-radius: 10px;
margin:10px;
width: 100px;
}
.center {
display: flex;
justify-content: center;
align-items: center;
width: 500px;
}
#description {
width: 500px;
}
</style>
<link rel="stylesheet" href="css/chessboard-1.0.0.min.css" integrity="sha384-q94+BZtLrkL1/ohfjR8c6L+A6qzNH9R2hBLwyoAfu3i/WCvQjzL2RQJ3uNHDISdU" crossorigin="anonymous">
</head>
<body>
<div id="main-container">
<div id="description">
<b>wchess : voice-controlled chess using Whisper + WebAssembly</b>
<br><br>
This is a demonstration of using Whisper to recognize voice commands in the browser.
<br><br>
Usage:<br>
<ul>
<li>Select a Whisper model</li>
<li>Accept the microphone permission request if prompted</li>
<li>Hold the button and say a chess move (e.g. "Knight to c3")</li>
<li>Release the button and wait for the move to be recognized</li>
<li>Repeat</li>
</ul>
Examples:<br>
<ul>
<li><b>"d4"</b></li>
<li><b>"e2 e4"</b></li>
<li><b>"Knight f3"</b></li>
<li><b>"Bishop to b5"</b></li>
</ul>
Features:<br>
<ul>
<li>Model quantization for reduced memory footprint (~42MB)</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/pull/1229">Grammar-based sampling</a> for improved recognition accuracy</li>
</ul>
<b>
Note that not all chess moves are supported. For example, castling and pawn promotion
currently do not work, but can be easily implemented. There could also be some bugs in
the move handling logic in general. The main reason for that is to keep the implementation
simple. The assumption is that a real application would already have proper
move-validation logic in place.<br><br>
The main purpose of this example is to demonstrate the capabilities of whisper.cpp and
its application in the browser for voice recognition locally on your device.
</b>
<br><br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/wchess">GitHub</a>.
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
</div>
<hr>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper()">tiny.en (Q8_0, 42 MB)</button>
<span id="fetch-whisper-progress"></span>
<br><br>
<button id="clear" onclick="clearCache()">Clear browser cache</button>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<div id="game">
<br>
<div id="chessboard" style="width: 500px"></div>
<script src="js/jquery-3.7.1.min.js"></script>
<script src="js/chessboard-1.0.0.min.js"></script>
<script>
var board = Chessboard('chessboard', 'start')
var move_count = 0;
</script>
<br>
<div id="state">
Status: <b><span id="state-status">select model</span></b>
<div id="input" class="center">
<button id="toggler" class="button" onselectstart="return false" style="display: none">Hold</button>
</div>
<pre id="state-grammar">[The grammar will be displayed here]</pre>
<pre id="state-moves">[The moves will be displayed here]</pre>
</div>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/command.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="js/helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// the command instance
var instance = null;
// model name
var model_whisper = null;
var model_file = null;
var module_ready = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Module initialized successfully!');
module_ready = true;
initInstance();
}
};
function initInstance() {
if (!module_ready || !model_file || instance) return
instance = Module.init(model_file);
if (instance) {
setStatus('Ready');
printTextarea("js: whisper initialized, instance: " + instance);
}
else {
printTextarea("js: failed to initialize whisper");
}
}
function setStatus(text) {
document.getElementById('state-status').innerHTML = text;
}
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
model_file = fname;
initInstance();
}
function loadWhisper() {
setStatus('Loading')
//let url = 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin';
let url = 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q8_0.bin';
let dst = 'whisper.bin';
let size_mb = 42;
model_whisper = 'tiny.en-q8_0';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model_whisper + '" ... ';
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
let cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
let cbCancel = function() {
var el;
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
// init audio capture so that the user receives a permission request
{
let context = new AudioContext({
sampleRate: 16000,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
s.getTracks().forEach(function(track) {
track.stop();
});
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
context.close();
}
document.getElementById('toggler').style.display = 'block';
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
if (mediaRecorder) {
mediaRecorder.stop();
}
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
let audio = renderedBuffer.getChannelData(0);
printTextarea('js: number of samples: ' + audio.length);
Module.set_audio(instance, audio);
});
mediaRecorder = null;
context = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
stream.getTracks().forEach(function(track) {
track.stop();
});
};
mediaRecorder.start();
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
}
//
// main
//
var nLines = 0;
var movesAll = '';
// document.body.addEventListener('keydown', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "";
// onStart();
// }
// }, true);
// document.body.addEventListener('keyup', function(event) {
// if (event.keyCode === 32) {
// document.getElementById('toggler').innerText = "Hold";
// onStop();
// }
// }, true);
document.getElementById('toggler').addEventListener("touchstart", function(event){
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener("touchend", function(event){
this.innerText = "Hold";
onStop();
}, true)
document.getElementById('toggler').addEventListener('mousedown', function(event) {
this.innerText = "";
onStart();
}, true);
document.getElementById('toggler').addEventListener('mouseup', function(event) {
this.innerText = "Hold";
onStop();
}, true);
function onStart() {
if (!instance) return;
setStatus('Listening');
startRecording();
}
function onStop() {
setStatus('Processing');
printTextarea('js: stopping recording ...');
stopRecording();
}
function setMove(move, prob) {
if (move != null && move.length > 1) {
let gameOver = move[move.length - 1] === '#';
if (gameOver) {
move = move.substring(0, move.length - 1);
document.getElementById('toggler').disabled = true;
}
board.move(move);
movesAll += move + ', prob = ' + prob.toFixed(2) + '% <br>';
nLines++;
// if more than 10 lines, remove the first line
if (nLines > 10) {
var i = movesAll.indexOf('<br>');
if (i > 0) {
movesAll = movesAll.substring(i + 4);
nLines--;
}
}
++move_count;
setStatus(gameOver ? 'Done' : move_count % 2 ? 'Black\'s turn' : 'White\'s turn');
document.getElementById('state-moves').innerHTML = movesAll;
}
else {
setStatus('Failed. ' + (move_count % 2 ? 'Black\'s turn' : 'White\'s turn'));
}
}
function setGrammar(grammar) {
document.getElementById('state-grammar').innerHTML = grammar;
}
</script>
<script type="text/javascript" src="js/chess.js"></script>
</body>
</html>

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,141 @@
#include <WChess.h>
#include <emscripten.h>
#include <emscripten/bind.h>
#include <condition_variable>
#include <mutex>
#include <thread>
constexpr int N_THREAD = 8;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::condition_variable g_cv;
bool g_running(false);
std::vector<float> g_pcmf32;
void set_move(const std::string & move, float prob) {
MAIN_THREAD_EM_ASM({
setMove(UTF8ToString($0), $1)
}, move.c_str(), prob);
}
void set_grammar(const std::string & grammar) {
MAIN_THREAD_EM_ASM({
setGrammar(UTF8ToString($0))
}, grammar.c_str());
}
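// get_audio blocks the worker thread on the condition variable until the UI
// thread hands over samples via set_audio(), or until shutdown is requested.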
bool get_audio(std::vector<float> & audio) {
std::unique_lock<std::mutex> lock(g_mutex);
g_cv.wait(lock, [] { return !g_running || !g_pcmf32.empty(); });
if (!g_running) return false;
audio = std::move(g_pcmf32);
return true;
}
void wchess_main(size_t i) {
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.no_timestamps = true;
wparams.max_tokens = 32;
wparams.audio_ctx = 1280; // partial encoder context for better performance
wparams.temperature = 0.0f;
wparams.temperature_inc = 2.0f;
wparams.greedy.best_of = 1;
wparams.beam_search.beam_size = 1;
wparams.language = "en";
wparams.grammar_penalty = 100.0;
wparams.initial_prompt = "bishop to c3, rook to d4, knight to e5, d4 d5, knight to c3, c3, queen to d4, king b1, pawn to a1, bishop to b2, knight to c3,";
printf("command: using %d threads\n", wparams.n_threads);
WChess::callbacks cb;
cb.get_audio = get_audio;
cb.set_move = set_move;
cb.set_grammar = set_grammar;
WChess(g_contexts[i], wparams, cb, {}).run();
if (i < g_contexts.size()) {
whisper_free(g_contexts[i]);
g_contexts[i] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
wchess_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t /* index */) {
{
std::unique_lock<std::mutex> lock(g_mutex);
g_running = false;
}
g_cv.notify_one();
}));
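// set_audio copies the incoming JS Float32Array into g_pcmf32 by constructing
// a typed-array view over the WASM heap at the vector's storage and calling set().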
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
g_cv.notify_one();
return 0;
}));
}

View File

@ -206,6 +206,7 @@ void AudioInputCallback(void * inUserData,
params.offset_ms = 0;
params.no_context = true;
params.single_segment = self->stateInp.isRealtime;
params.no_timestamps = params.single_segment;
CFTimeInterval startTime = CACurrentMediaTime();

View File

@ -8,15 +8,15 @@ enum WhisperError: Error {
// Meet Whisper C++ constraint: Don't access from more than one thread at a time.
actor WhisperContext {
private var context: OpaquePointer
init(context: OpaquePointer) {
self.context = context
}
deinit {
whisper_free(context)
}
func fullTranscribe(samples: [Float]) {
// Leave 2 processors free (i.e. the high-efficiency cores).
let maxThreads = max(1, min(8, cpuCount() - 2))
@ -24,17 +24,17 @@ actor WhisperContext {
var params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY)
"en".withCString { en in
// Adapted from whisper.objc
params.print_realtime = true
params.print_progress = false
params.print_timestamps = true
params.print_special = false
params.translate = false
params.language = en
params.n_threads = Int32(maxThreads)
params.offset_ms = 0
params.no_context = true
params.single_segment = false
whisper_reset_timings(context)
print("About to run whisper_full")
samples.withUnsafeBufferPointer { samples in
@ -46,7 +46,7 @@ actor WhisperContext {
}
}
}
func getTranscription() -> String {
var transcription = ""
for i in 0..<whisper_full_n_segments(context) {
@ -54,7 +54,7 @@ actor WhisperContext {
}
return transcription
}
static func createContext(path: String) throws -> WhisperContext {
var params = whisper_context_default_params()
#if targetEnvironment(simulator)

View File

@ -17,12 +17,12 @@ else
encoder_only=$2
fi
models=( \
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
models=( \
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" "medium-dis" \
"large-v2" "large-v2-q4_0" "large-v2-q4_1" "large-v2-q5_0" "large-v2-q5_1" "large-v2-q8_0" "large-v2-dis" \
)
if [ "$encoder_only" -eq 0 ]; then
@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n"
fi
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
for model in "${models[@]}"; do
# actual run
@ -95,6 +95,6 @@ for model in "${models[@]}"; do
commit=$(git rev-parse --short HEAD)
if [ $ret -eq 0 ]; then
printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
fi
done

View File

@ -61,7 +61,9 @@ models = [
"ggml-small.bin",
"ggml-medium.en.bin",
"ggml-medium.bin",
"ggml-large.bin",
"ggml-large-v1.bin",
"ggml-large-v2.bin",
"ggml-large-v3.bin",
]

View File

@ -1,6 +1,6 @@
#!/bin/bash
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large" )
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" )
for model in "${models[@]}"; do
python3 models/convert-pt-to-ggml.py ~/.cache/whisper/$model.pt ../whisper models/

extra/sync-ggml-am.sh Executable file
View File

@ -0,0 +1,154 @@
#!/bin/bash
#
# Synchronize ggml changes to whisper.cpp
#
# Usage:
#
# $ cd /path/to/whisper.cpp
# $ ./extra/sync-ggml-am.sh
#
set -e
sd=$(dirname $0)
cd $sd/../
SRC_WHISPER=$(pwd)
SRC_GGML=$(cd ../ggml; pwd)
if [ ! -d $SRC_GGML ]; then
echo "ggml not found at $SRC_GGML"
exit 1
fi
lc=$(cat $SRC_WHISPER/extra/sync-ggml.last)
echo "Syncing ggml changes since commit $lc"
cd $SRC_GGML
git log --oneline $lc..HEAD
git log --oneline $lc..HEAD --reverse | grep -v "(whisper/[0-9]*)" | cut -d' ' -f1 > $SRC_WHISPER/ggml-commits
if [ ! -s $SRC_WHISPER/ggml-commits ]; then
rm -v $SRC_WHISPER/ggml-commits
echo "No new commits"
exit 0
fi
if [ -f $SRC_WHISPER/ggml-src.patch ]; then
rm -v $SRC_WHISPER/ggml-src.patch
fi
while read c; do
git format-patch -k $c~1..$c --stdout -- \
include/ggml/ggml*.h \
src/ggml*.h \
src/ggml*.c \
src/ggml*.cpp \
src/ggml*.m \
src/ggml*.metal \
src/ggml*.cu \
examples/common.h \
examples/common.cpp \
examples/common-ggml.h \
examples/common-ggml.cpp \
examples/whisper/whisper.h \
examples/whisper/whisper.cpp \
examples/whisper/main.cpp \
examples/whisper/quantize.cpp \
>> $SRC_WHISPER/ggml-src.patch
done < $SRC_WHISPER/ggml-commits
rm -v $SRC_WHISPER/ggml-commits
# delete files if empty
if [ ! -s $SRC_WHISPER/ggml-src.patch ]; then
rm -v $SRC_WHISPER/ggml-src.patch
fi
cd $SRC_WHISPER
if [ -f $SRC_WHISPER/ggml-src.patch ]; then
# replace PR numbers
#
# Subject: some text (#1234)
# Subject: some text (ggml/1234)
cat ggml-src.patch | sed -e 's/^Subject: \(.*\) (#\([0-9]*\))/Subject: \1 (ggml\/\2)/' > ggml-src.patch.tmp
mv ggml-src.patch.tmp ggml-src.patch
cat ggml-src.patch | sed -e 's/^\(.*\) (#\([0-9]*\))$/\1 (ggml\/\2)/' > ggml-src.patch.tmp
mv ggml-src.patch.tmp ggml-src.patch
# replace filenames:
#
# src/ggml.c -> ggml.c
# src/ggml-alloc.c -> ggml-alloc.c
# src/ggml-backend-impl.h -> ggml-backend-impl.h
# src/ggml-backend.c -> ggml-backend.c
# src/ggml-cuda.cu -> ggml-cuda.cu
# src/ggml-cuda.h -> ggml-cuda.h
# src/ggml-impl.h -> ggml-impl.h
# src/ggml-metal.h -> ggml-metal.h
# src/ggml-metal.m -> ggml-metal.m
# src/ggml-mpi.h -> ggml-mpi.h
# src/ggml-mpi.c -> ggml-mpi.c
# src/ggml-opencl.cpp -> ggml-opencl.cpp
# src/ggml-opencl.h -> ggml-opencl.h
# src/ggml-quants.c -> ggml-quants.c
# src/ggml-quants.h -> ggml-quants.h
# include/ggml/ggml.h -> ggml.h
# include/ggml/ggml-alloc.h -> ggml-alloc.h
# include/ggml/ggml-backend.h -> ggml-backend.h
#
# examples/common.h -> examples/common.h
# examples/common.cpp -> examples/common.cpp
# examples/common-ggml.h -> examples/common-ggml.h
# examples/common-ggml.cpp -> examples/common-ggml.cpp
#
# examples/whisper/whisper.h -> whisper.h
# examples/whisper/whisper.cpp -> whisper.cpp
# examples/whisper/main.cpp -> examples/main/main.cpp
# examples/whisper/quantize.cpp -> examples/quantize/quantize.cpp
cat ggml-src.patch | sed \
-e 's/src\/ggml\.c/ggml.c/g' \
-e 's/src\/ggml-alloc\.c/ggml-alloc.c/g' \
-e 's/src\/ggml-backend-impl\.h/ggml-backend-impl.h/g' \
-e 's/src\/ggml-backend\.c/ggml-backend.c/g' \
-e 's/src\/ggml-cuda\.cu/ggml-cuda.cu/g' \
-e 's/src\/ggml-cuda\.h/ggml-cuda.h/g' \
-e 's/src\/ggml-impl\.h/ggml-impl.h/g' \
-e 's/src\/ggml-metal\.h/ggml-metal.h/g' \
-e 's/src\/ggml-metal\.m/ggml-metal.m/g' \
-e 's/src\/ggml-mpi\.h/ggml-mpi.h/g' \
-e 's/src\/ggml-mpi\.c/ggml-mpi.c/g' \
-e 's/src\/ggml-opencl\.cpp/ggml-opencl.cpp/g' \
-e 's/src\/ggml-opencl\.h/ggml-opencl.h/g' \
-e 's/src\/ggml-quants\.c/ggml-quants.c/g' \
-e 's/src\/ggml-quants\.h/ggml-quants.h/g' \
-e 's/include\/ggml\/ggml\.h/ggml.h/g' \
-e 's/include\/ggml\/ggml-alloc\.h/ggml-alloc.h/g' \
-e 's/include\/ggml\/ggml-backend\.h/ggml-backend.h/g' \
-e 's/examples\/common\.h/examples\/common.h/g' \
-e 's/examples\/common\.cpp/examples\/common.cpp/g' \
-e 's/examples\/common-ggml\.h/examples\/common-ggml.h/g' \
-e 's/examples\/common-ggml\.cpp/examples\/common-ggml.cpp/g' \
-e 's/examples\/whisper\/whisper\.h/whisper.h/g' \
-e 's/examples\/whisper\/whisper\.cpp/whisper.cpp/g' \
-e 's/examples\/whisper\/main\.cpp/examples\/main\/main.cpp/g' \
-e 's/examples\/whisper\/quantize\.cpp/examples\/quantize\/quantize.cpp/g' \
> ggml-src.patch.tmp
mv ggml-src.patch.tmp ggml-src.patch
git am ggml-src.patch
rm -v $SRC_WHISPER/ggml-src.patch
fi
# update last commit
cd $SRC_GGML
git log -1 --format=%H > $SRC_WHISPER/extra/sync-ggml.last
echo "Done"
exit 0

extra/sync-ggml.last Normal file
View File

@ -0,0 +1 @@
3fd01e00e40583ccd4b393a7c6502d6a4455a1d5

extra/sync-llama.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash
cp -rpv ../llama.cpp/llama.h ./examples/talk-llama/llama.h
cp -rpv ../llama.cpp/llama.cpp ./examples/talk-llama/llama.cpp
cp -rpv ../llama.cpp/unicode.h ./examples/talk-llama/unicode.h

View File

@ -72,7 +72,7 @@ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * t
// check if a tensor is allocated by this buffer
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
return tensor->buffer == alloc->buffer;
return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
}
static bool ggml_is_view(struct ggml_tensor * t) {
@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
size_t cur_max = (char*)addr - (char*)alloc->data + size;
size_t cur_max = (char*)addr - (char*)alloc->base + size;
if (cur_max > alloc->max_size) {
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
for (int i = 0; i < 1024; i++) {
@ -168,10 +168,6 @@ static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor *
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
if (!alloc->measure) {
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
}
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
#endif
@ -237,7 +233,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
}
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(data, size);
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
@ -446,18 +442,18 @@ static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * n
return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
}
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view) {
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
ggml_tallocr_t alloc = node_tallocr(galloc, view);
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
view->backend = view->view_src->backend;
view->buffer = view->view_src->buffer;
if (update_backend) {
view->backend = view->view_src->backend;
}
// views are initialized in the alloc buffer rather than the view_src buffer
view->buffer = alloc->buffer;
view->data = (char *)view->view_src->data + view->view_offs;
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
if (!alloc->measure) {
ggml_backend_buffer_init_tensor(alloc->buffer, view);
@ -469,7 +465,7 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
if (node->data == NULL) {
if (ggml_is_view(node)) {
init_view(galloc, node);
init_view(galloc, node, true);
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
@ -499,15 +495,14 @@ static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->view_src = view_src;
view_src_hn->n_views += 1;
init_view(galloc, node);
init_view(galloc, node, false);
return;
}
}
else {
} else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->view_src = parent;
p_hn->n_views += 1;
init_view(galloc, node);
init_view(galloc, node, false);
return;
}
}
@ -537,7 +532,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
hash_get(galloc, view_src)->n_views += 1;
if (node->buffer == NULL && node->data != NULL) {
// view of a pre-allocated tensor, didn't call init_view() yet
init_view(galloc, node);
init_view(galloc, node, true);
}
}
@ -548,7 +543,7 @@ static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
}
hash_get(galloc, parent)->n_children += 1;
if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
init_view(galloc, parent);
init_view(galloc, parent, true);
}
}
}
@ -663,7 +658,7 @@ size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, st
return max_size;
}
void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_alloct) {
void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
const size_t hash_size = hash_set.size;
GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
@ -686,7 +681,7 @@ void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * grap
// reset hash values
memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
galloc->hash_allocs = hash_node_alloct;
galloc->hash_allocs = hash_node_talloc;
ggml_tallocr_alloc_graph_impl(galloc, graph);
@ -740,6 +735,10 @@ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
}
void ggml_allocr_free(ggml_allocr_t alloc) {
if (alloc == NULL) {
return;
}
ggml_gallocr_free(alloc->galloc);
ggml_tallocr_free(alloc->talloc);
free(alloc);
@ -764,3 +763,48 @@ size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
}
// utils
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
size_t alignment = ggml_backend_buft_get_alignment(buft);
size_t nbytes = 0;
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL && t->view_src == NULL) {
nbytes += GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
}
}
if (nbytes == 0) {
// all the tensors in the context are already allocated
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->data == NULL) {
if (t->view_src == NULL) {
ggml_tallocr_alloc(tallocr, t);
} else {
ggml_backend_view_init(buffer, t);
}
} else {
if (t->view_src != NULL) {
// view of a pre-allocated tensor
ggml_backend_view_init(buffer, t);
}
}
}
ggml_tallocr_free(tallocr);
return buffer;
}
ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
}
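The two utils above allocate every still-unallocated tensor of a no-alloc context into one backend buffer, initializing views against their sources. A minimal usage sketch with the CPU backend (sizes are arbitrary):

struct ggml_init_params ip = {
    /*.mem_size   =*/ ggml_tensor_overhead() * 8,
    /*.mem_buffer =*/ NULL,
    /*.no_alloc   =*/ true, // required: tensor data is placed in the backend buffer
};
struct ggml_context * ctx = ggml_init(ip);
struct ggml_tensor  * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
struct ggml_tensor  * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
// ... use a and b ...
ggml_backend_buffer_free(buf);
ggml_backend_free(backend);
ggml_free(ctx);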

View File

@ -8,6 +8,7 @@ extern "C" {
struct ggml_backend;
struct ggml_backend_buffer;
struct ggml_backend_buffer_type;
//
// Legacy API
@ -42,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph
// ggml-backend v2 API
//
// Seperate tensor and graph allocator objects
// Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
@ -80,6 +81,12 @@ GGML_API void ggml_gallocr_alloc_graph_n(
struct ggml_hash_set hash_set,
ggml_tallocr_t * hash_node_talloc);
// Utils
// Create a buffer and allocate all the tensors in a ggml_context
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, struct ggml_backend_buffer_type * buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, struct ggml_backend * backend);
#ifdef __cplusplus
}
#endif

View File

@ -12,31 +12,54 @@ extern "C" {
// Backend buffer
//
// buffer type
typedef void * ggml_backend_buffer_type_context_t;
struct ggml_backend_buffer_type_i {
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
// check if tensor data is in host memory
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
bool (*is_host) (ggml_backend_buffer_type_t buft);
};
struct ggml_backend_buffer_type {
struct ggml_backend_buffer_type_i iface;
ggml_backend_buffer_type_context_t context;
};
// buffer
typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
void (*free_buffer) (ggml_backend_buffer_t buffer);
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
void * (*get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) copy tensor between different buffer types, allowing single-copy transfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
};
struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface;
ggml_backend_t backend;
struct ggml_backend_buffer_i iface;
ggml_backend_buffer_type_t buft;
ggml_backend_buffer_context_t context;
size_t size;
};
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend,
ggml_backend_buffer_t ggml_backend_buffer_init(
ggml_backend_buffer_type_t buft,
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size);
//
// Backend
//
@ -49,20 +72,17 @@ extern "C" {
void (*free)(ggml_backend_t backend);
// buffer allocation
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
// get buffer alignment
size_t (*get_alignment)(ggml_backend_t backend);
// tensor data access
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
// (optional) asynchronous tensor data access
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*synchronize) (ggml_backend_t backend);
// (optional) copy tensor between different backends, allowing single-copy transfers
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
// (optional) asynchronous tensor copy
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*synchronize)(ggml_backend_t backend);
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
@ -70,7 +90,7 @@ extern "C" {
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph without a plan
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
// check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
@ -82,6 +102,15 @@ extern "C" {
ggml_backend_context_t context;
};
//
// Backend registry
//
typedef ggml_backend_t (*ggml_backend_init_fn)(const char * params, void * user_data);
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
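A hypothetical sketch of how a backend would make itself discoverable through this hook (all my_backend_* names are assumptions, not part of the API above):

static ggml_backend_t my_backend_init_fn(const char * params, void * user_data) {
    (void) params; (void) user_data;
    return my_backend_init(); // hypothetical constructor
}

// typically called once at startup, before the registry is queried
ggml_backend_register("MyBackend", my_backend_init_fn, my_backend_buffer_type(), NULL);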
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large

View File

@ -7,41 +7,47 @@
extern "C" {
#endif
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
//
// Backend buffer
//
struct ggml_backend_buffer;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
// buffer type
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
// backend buffer functions
// buffer
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
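With the split above, allocation properties are queried on the buffer type while data access goes through the concrete buffer. A short sketch against the CPU buffer type (the size is arbitrary):

ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
size_t align = ggml_backend_buft_get_alignment(buft);
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16*1024*1024);
void * base = ggml_backend_buffer_get_base(buf); // aligned to 'align'
ggml_backend_buffer_clear(buf, 0);               // zero the whole buffer
ggml_backend_buffer_free(buf);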
//
// Backend
//
struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
@ -52,11 +58,12 @@ extern "C" {
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
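ggml_backend_graph_compute now returns a status instead of void, so callers can detect failed graph execution. A minimal caller-side sketch (assuming false signals failure; gf is a graph built elsewhere):

if (!ggml_backend_graph_compute(backend, gf)) {
    fprintf(stderr, "ggml_backend_graph_compute failed\n");
}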
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); // automatic fallback to sync copy
//
// CPU backend
@ -68,8 +75,27 @@ extern "C" {
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
// Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
#ifdef GGML_USE_CPU_HBM
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
//
// Backend registry
//
// The backend registry enumerates all available backends and allows initializing them in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
GGML_API size_t ggml_backend_reg_find_by_name(const char * name);
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params]
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
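The registry lets frontends discover and construct backends by name instead of hard-coding init calls. A sketch (the "CPU" name is an assumption about what is registered):

for (size_t i = 0; i < ggml_backend_reg_get_count(); ++i) {
    printf("backend %zu: %s\n", i, ggml_backend_reg_get_name(i));
}
// "name[:params]" form; params after the colon are backend-specific
ggml_backend_t be = ggml_backend_reg_init_backend_from_str("CPU");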
//
// Backend scheduler
@ -131,6 +157,32 @@ extern "C" {
ggml_backend_sched_t sched,
struct ggml_cgraph * graph);
//
// Utils
//
struct ggml_backend_graph_copy {
ggml_backend_buffer_t buffer;
struct ggml_context * ctx_allocated;
struct ggml_context * ctx_unallocated;
struct ggml_cgraph * graph;
};
// Copy a graph to a different backend
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
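ggml_backend_compare_graph_backend runs the graph on both backends and hands each pair of result tensors to the callback; returning false from the callback is assumed to stop the comparison. A sketch:

static bool eval_cb(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) user_data;
    // compare t1 (backend1) against t2 (backend2) here, e.g. max abs difference
    printf("node %d: %s\n", node_index, t1->name);
    return true; // keep going
}

// elsewhere: ggml_backend_compare_graph_backend(backend_cpu, backend_gpu, gf, eval_cb, NULL);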
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large

View File

@ -49,7 +49,15 @@ GGML_API int ggml_cuda_get_device_count(void);
GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
// backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
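Pinned (page-locked) host memory lets CUDA copy to and from the CPU without a staging buffer. A sketch of allocating from the host buffer type (the size is arbitrary):

ggml_backend_buffer_type_t host_buft = ggml_backend_cuda_host_buffer_type();
ggml_backend_buffer_t host_buf = ggml_backend_buft_alloc_buffer(host_buft, 64*1024*1024);
// tensors placed here are usable by the CPU backend, with faster CPU<->GPU transfers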
#ifdef __cplusplus
}

View File

@ -232,7 +232,7 @@ bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
// return index, asserts if table is full

View File

@ -87,7 +87,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
bool ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
//
// backend API
@ -98,8 +98,17 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
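A sketch of the family check (the assumption here is that 7 maps to MTLGPUFamilyApple7 per the referenced feature-set tables):

ggml_backend_t backend = ggml_backend_metal_init();
if (backend && !ggml_backend_metal_supports_family(backend, 7)) {
    // fall back to the CPU backend if required GPU features are missing
}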
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,20 +1,18 @@
#include "ggml.h"
#include "ggml-opencl.h"
#include <array>
#include <atomic>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream>
#include <vector>
#include <limits>
#define CL_TARGET_OPENCL_VERSION 110
#include <clblast.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "ggml.h"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

View File

@ -6,19 +6,19 @@
extern "C" {
#endif
void ggml_cl_init(void);
GGML_API void ggml_cl_init(void);
void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
void * ggml_cl_host_malloc(size_t size);
void ggml_cl_host_free(void * ptr);
GGML_API void * ggml_cl_host_malloc(size_t size);
GGML_API void ggml_cl_host_free(void * ptr);
void ggml_cl_free_data(const struct ggml_tensor* tensor);
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
#ifdef __cplusplus
}

View File

@ -19,7 +19,7 @@
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#ifdef __POWER9_VECTOR__
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
#include <altivec.h>
#undef bool
#define bool _Bool
@ -407,6 +407,22 @@ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#endif
#if !defined(__ARM_FEATURE_DOTPROD)
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#else
#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
#endif
#endif
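The fallback above folds the 16 int8 products into four int32 lanes with widening multiplies and pairwise adds. Its per-lane grouping differs from SDOT's 4-at-a-time grouping, but the call sites below reduce the vector with vaddvq_s32, so only the total matters. A scalar model of that total:

// scalar model: uses of ggml_vdotq_s32 horizontally sum the result,
// so the quantity that must match is the total of the 16 products
static inline int32_t vdotq_s32_total(const int8_t a[16], const int8_t b[16]) {
    int32_t sum = 0;
    for (int k = 0; k < 16; ++k) {
        sum += (int32_t) a[k] * (int32_t) b[k];
    }
    return sum;
}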
#if defined(__ARM_NEON) || defined(__wasm_simd128__)
@ -1368,7 +1384,12 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f
float max = x[0];
float sum_w = weights[0];
float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
for (volatile int i = 1; i < n; ++i) {
#else
for (int i = 1; i < n; ++i) {
#endif
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
float w = weights[i];
@ -2463,32 +2484,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#endif
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@ -2771,32 +2772,12 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restri
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
#endif
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
@ -2958,32 +2939,12 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#endif
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@ -3109,7 +3070,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri
size_t vl = __riscv_vsetvl_e8m1(qk/2);
// These tempory registers are for masking and shift operations
// These temporary registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
@ -3270,32 +3231,12 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
#endif
ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
@ -3545,34 +3486,13 @@ void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restri
const int8x16_t y1_0 = vld1q_s8(y1->qs);
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#else
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
#endif
ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
@ -3645,12 +3565,10 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m3 = vdupq_n_u8(0x3);
const uint8x16_t m4 = vdupq_n_u8(0xF);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t vzero = vdupq_n_s32(0);
#endif
const int32x4_t vzero = vdupq_n_s32(0);
ggml_int8x16x2_t q2bytes;
uint8_t aux[16];
@ -3658,7 +3576,6 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
float sum = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
@ -3672,7 +3589,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@ -3684,20 +3601,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
// We use this macro instead of a function call because for some reason
// the code runs 2-3% slower, even if the function is declared inline
#if defined(__ARM_FEATURE_DOTPROD)
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
#else
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
{\
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
}
#endif
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
@ -3705,26 +3611,23 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
for (int j = 0; j < QK_K/128; ++j) {
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
is += 8;
}
sum += d * isum;
sum += d * isum;
}
*s = sum;
@ -4038,11 +3941,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m3 = vdupq_n_u8(0x3);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t vzero = vdupq_n_s32(0);
#endif
const int32x4_t vzero = vdupq_n_s32(0);
ggml_int8x16x4_t q2bytes;
@ -4076,28 +3977,12 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
#if defined(__ARM_FEATURE_DOTPROD)
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
#else
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum1 += vaddvq_s16(p1) * scales[0];
isum2 += vaddvq_s16(p2) * scales[1];
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum1 += vaddvq_s16(p3) * scales[2];
isum2 += vaddvq_s16(p4) * scales[3];
#endif
sum += d * (isum1 + isum2);
}
*s = sum;
@ -4323,9 +4208,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
uint32_t utmp[4];
const uint8x16_t m3b = vdupq_n_u8(0x3);
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t vzero = vdupq_n_s32(0);
#endif
const uint8x16_t m0 = vdupq_n_u8(1);
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
@ -4377,22 +4260,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
#else
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
scale += 4;
q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
@ -4405,22 +4277,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
#else
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
scale += 4;
if (j == 0) {
@ -4752,7 +4613,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
vl = 16;
// retreive lane to multiply with scale
// retrieve lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
@ -4859,10 +4720,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t vzero = vdupq_n_s32(0);
#endif
const int32x4_t vzero = vdupq_n_s32(0);
const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4);
@ -4903,22 +4761,10 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
#else
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3];
#endif
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
sum += d * isum;
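All of the NEON hunks in this diff make the same mechanical change: the per-kernel `#if defined(__ARM_FEATURE_DOTPROD)` / `#else` pairs are collapsed into a single call to the new `ggml_vdotq_s32` alias. The alias itself is presumably defined in the suppressed ggml.c diff below; a minimal sketch of such a definition, consistent with the call sites here:

```c
#include <arm_neon.h>

#if defined(__ARM_FEATURE_DOTPROD)
// native SDOT path: each 32-bit lane accumulates a 4-element int8 dot product
#define ggml_vdotq_s32(acc, a, b) vdotq_s32(acc, a, b)
#else
// fallback for NEON without dotprod: widen to 16-bit products, then
// pairwise-add into the 32-bit accumulator lanes; the lane grouping differs
// from SDOT, but every call site reduces with vaddvq_s32, so the horizontal
// sum is identical
inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
}
#endif
```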
@@ -5223,11 +5069,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
uint32_t utmp[4];
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t mzero = vdupq_n_s32(0);
#endif
ggml_int8x16x2_t q4bytes;
ggml_int8x16x2_t q8bytes;
@@ -5264,44 +5107,22 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
#ifdef __ARM_FEATURE_DOTPROD
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
#endif
}
sumf += d * (sumi1 + sumi2);
@@ -5598,12 +5419,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t mzero = vdupq_n_s32(0);
#endif
float sumf = 0;
@@ -5631,41 +5449,20 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD
q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32_t sumi1 = vaddvq_s32(p1) * scales[0];
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
#else
q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0];
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1];
#endif
sumf += d * (sumi1 + sumi2);
}
*s = sumf - sum_mins;
@@ -5870,15 +5667,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
uint32_t utmp[4];
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t mzero = vdupq_n_s32(0);
#endif
ggml_int8x16x4_t q5bytes;
@@ -5933,28 +5726,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
#else
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;
#endif
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
}
sumf += d * sumi - dmin * sumi_mins;
}
*s = sumf;
@@ -6306,12 +6082,9 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t mh = vdupq_n_u8(16);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t mzero = vdupq_n_s32(0);
#endif
ggml_int8x16x4_t q5bytes;
ggml_uint8x16x4_t q5h;
@@ -6343,32 +6116,12 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2]));
q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]));
int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1]));
int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]));
int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3]));
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
#else
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1);
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3);
sumf += d*sumi;
#endif
}
*s = sumf;
@@ -6595,13 +6348,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t vzero = vdupq_n_s32(0);
#endif
//const int8x16_t m32s = vdupq_n_s8(32);
const uint8x16_t mone = vdupq_n_u8(3);
@@ -6621,7 +6371,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@@ -6653,31 +6403,13 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;
#else
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
scale += 2;
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
scale += 2;
#endif
q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4);
@@ -6698,34 +6430,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
scale += 4;
//for (int l = 0; l < 4; ++l) {
// const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
// isum += vaddvq_s32(p) * *scale++;
//}
#else
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
scale += 2;
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
scale += 2;
#endif
}
//sum += isum * d_all * y[i].d;
sum += d_all * y[i].d * (isum - 32 * isum_mins);
@@ -7071,14 +6780,11 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);
const int8x16_t m32s = vdupq_n_s8(32);
#if defined(__ARM_FEATURE_DOTPROD)
const int32x4_t vzero = vdupq_n_s32(0);
#endif
const uint8x16_t mone = vdupq_n_u8(3);
@@ -7114,26 +6820,10 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s);
q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s);
#if defined(__ARM_FEATURE_DOTPROD)
isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
#else
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
#endif
isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
sum += isum * d_all * y[i].d;

ggml.c (1839 changed lines)

File diff suppressed because it is too large

ggml.h (152 changed lines)

@@ -215,9 +215,9 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 1024
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_SRC 10
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
@@ -244,11 +244,10 @@
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
fflush(stderr); \
fflush(stdout); \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
ggml_print_backtrace(); \
exit(1); \
abort(); \
} \
} while (0)
@@ -256,6 +255,8 @@
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
#elif defined(__GNUC__)
#define GGML_UNREACHABLE() __builtin_unreachable()
#elif defined(_MSC_VER)
#define GGML_UNREACHABLE() __assume(0)
#else
#define GGML_UNREACHABLE() ((void) 0)
#endif
@@ -284,13 +285,27 @@
const type prefix##3 = (pointer)->array[3]; \
GGML_UNUSED(prefix##3);
#define GGML_TENSOR_UNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
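After this rewrite, the locals macros name their variables off fixed `src0`/`dst` identifiers. A sketch of how a compute kernel inside ggml.c would use them (the function name and loop body are hypothetical):

```c
// GGML_TENSOR_UNARY_OP_LOCALS defines ne00..ne03 / nb00..nb03 from src0
// and ne0..ne3 / nb0..nb3 from dst (see GGML_TENSOR_LOCALS above);
// assumes src0/dst are in scope under exactly those names
static void ggml_compute_forward_example(const struct ggml_tensor * src0,
                                         struct ggml_tensor * dst) {
    GGML_TENSOR_UNARY_OP_LOCALS

    for (int64_t i = 0; i < ne0; ++i) {
        // ... element-wise work over dst rows ...
    }
}
```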
#ifdef __cplusplus
extern "C" {
#endif
#if defined(__ARM_NEON) && defined(__CUDACC__)
typedef half ggml_fp16_t;
#elif defined(__ARM_NEON)
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
typedef __fp16 ggml_fp16_t;
#else
typedef uint16_t ggml_fp16_t;
@@ -330,6 +345,12 @@ extern "C" {
GGML_TYPE_COUNT,
};
// precision
enum ggml_prec {
GGML_PREC_DEFAULT,
GGML_PREC_F32,
};
enum ggml_backend_type {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_GPU = 10,
@@ -382,6 +403,7 @@ extern "C" {
GGML_OP_GROUP_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_OUT_PROD,
GGML_OP_SCALE,
@@ -408,8 +430,10 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
@@ -449,7 +473,8 @@ extern "C" {
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_LEAKY
GGML_UNARY_OP_COUNT,
};
enum ggml_object_type {
@@ -461,7 +486,8 @@ extern "C" {
enum ggml_log_level {
GGML_LOG_LEVEL_ERROR = 2,
GGML_LOG_LEVEL_WARN = 3,
GGML_LOG_LEVEL_INFO = 4
GGML_LOG_LEVEL_INFO = 4,
GGML_LOG_LEVEL_DEBUG = 5
};
// ggml object
@@ -485,7 +511,6 @@ extern "C" {
struct ggml_backend_buffer * buffer;
int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements
size_t nb[GGML_MAX_DIMS]; // stride in bytes:
// nb[0] = ggml_type_size(type)
@@ -517,7 +542,7 @@ extern "C" {
void * extra; // extra things e.g. for ggml-cuda.cu
char padding[12];
char padding[8];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -622,16 +647,22 @@ extern "C" {
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
GGML_API int ggml_blck_size (enum ggml_type type);
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
GGML_API int ggml_blck_size(enum ggml_type type);
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");
GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_quantized(enum ggml_type type);
@@ -642,6 +673,11 @@ extern "C" {
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
@@ -702,8 +738,8 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
// Context tensor enumeration and lookup
GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
@@ -774,6 +810,9 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -938,15 +977,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_leaky(
GGML_API struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * a, float negative_slope, bool inplace);
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
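The old parameterless `ggml_leaky` becomes `ggml_leaky_relu` with an explicit slope; a sketch under the new signature (the wrapper and slope value are hypothetical):

```c
#include "ggml.h"

static struct ggml_tensor * leaky(struct ggml_context * ctx, struct ggml_tensor * x) {
    // negative_slope = 0.1f is an arbitrary example value; inplace = false
    return ggml_leaky_relu(ctx, x, 0.1f, false);
}
```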
@@ -1028,6 +1066,22 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// change the precision of a matrix multiplication
// set to GGML_PREC_F32 for higher precision (useful for phi-2)
GGML_API void ggml_mul_mat_set_prec(
struct ggml_tensor * a,
enum ggml_prec prec);
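A sketch of the new per-op precision override (tensor names assumed); the flag is set on the result tensor of the matmul:

```c
#include "ggml.h"

// hypothetical wrapper: force F32 accumulation for one matmul (e.g. for phi-2)
static struct ggml_tensor * mul_mat_f32(struct ggml_context * ctx,
                                        struct ggml_tensor * a,
                                        struct ggml_tensor * b) {
    struct ggml_tensor * cur = ggml_mul_mat(ctx, a, b);
    ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
    return cur;
}
```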
// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
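Per the comment above, `ggml_mul_mat_id` routes through one of `n_as` expert matrices; a hypothetical MoE-style wrapper:

```c
#include "ggml.h"

// roughly equivalent to ggml_mul_mat(ctx, experts[ids[id]], b),
// with the expert selection resolved at graph-compute time
static struct ggml_tensor * route_expert(struct ggml_context * ctx,
                                         struct ggml_tensor * const experts[], int n_experts,
                                         struct ggml_tensor * ids, int id,
                                         struct ggml_tensor * b) {
    return ggml_mul_mat_id(ctx, experts, n_experts, ids, id, b);
}
```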
// A: m columns, n rows,
// B: p columns, n rows,
// result is m columns, p rows
@@ -1043,13 +1097,13 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
float s);
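`ggml_scale` now takes a plain float instead of a 1-element tensor; a sketch of the common attention-scaling call (wrapper and `n_embd_head` hypothetical):

```c
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * scale_kq(struct ggml_context * ctx,
                                     struct ggml_tensor * kq, int n_embd_head) {
    return ggml_scale(ctx, kq, 1.0f/sqrtf((float) n_embd_head));
}
```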
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
@@ -1235,6 +1289,7 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1283,6 +1338,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// fused soft_max(a*scale + mask)
// mask is optional
GGML_API struct ggml_tensor * ggml_soft_max_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * mask,
float scale);
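A sketch of the fused path (tensor names hypothetical); per the comment above, the mask may be NULL:

```c
#include <math.h>
#include "ggml.h"

// computes soft_max(kq*scale + kq_mask) in one op instead of
// separate scale/add/soft_max nodes
static struct ggml_tensor * attn_probs(struct ggml_context * ctx,
                                       struct ggml_tensor * kq,
                                       struct ggml_tensor * kq_mask, // may be NULL
                                       int n_embd_head) {
    return ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head));
}
```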
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1371,8 +1434,13 @@ extern "C" {
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base,
bool xpos_down);
@@ -1508,6 +1576,32 @@ extern "C" {
struct ggml_tensor * a,
int scale_factor);
// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3);
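Usage sketch (the helper is hypothetical): `p0..p3` append zeros at the end of dims 0..3:

```c
#include "ggml.h"

// zero-pad dim 0 of `a` up to the next multiple of 32
static struct ggml_tensor * pad_to_32(struct ggml_context * ctx, struct ggml_tensor * a) {
    const int p0 = (int) ((32 - a->ne[0] % 32) % 32);
    return ggml_pad(ctx, a, p0, 0, 0, 0);
}
```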
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,
GGML_SORT_DESC,
};
GGML_API struct ggml_tensor * ggml_argsort(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_sort_order order);
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
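A sketch (assuming, consistent with the enum above, that `ggml_top_k` amounts to a descending `ggml_argsort` truncated to the first k indices per row):

```c
#include "ggml.h"

// indices of the 4 largest entries in each row of `logits`
static struct ggml_tensor * top4(struct ggml_context * ctx, struct ggml_tensor * logits) {
    // comparable to keeping the first 4 columns of
    // ggml_argsort(ctx, logits, GGML_SORT_DESC)
    return ggml_top_k(ctx, logits, 4);
}
```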
GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx,
struct ggml_tensor * q,
@@ -1569,7 +1663,6 @@ extern "C" {
int kh);
// used in sam
GGML_API struct ggml_tensor * ggml_add_rel_pos(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1744,7 +1837,7 @@ extern "C" {
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads);
GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1);
GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1);
GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads
GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);
@@ -2040,14 +2133,16 @@ extern "C" {
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id);
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i);
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i);
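With the new getter, callers can inspect tensor types straight from a GGUF header; a sketch (the helper is hypothetical, and `gguf_find_tensor` is assumed to return -1 on a miss):

```c
#include "ggml.h"

static enum ggml_type tensor_type_of(const struct gguf_context * gctx, const char * name) {
    const int i = gguf_find_tensor(gctx, name); // assumed: -1 when absent
    return i >= 0 ? gguf_get_tensor_type(gctx, i) : GGML_TYPE_COUNT; // COUNT as "not found"
}
```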
// overrides existing values or adds a new one
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
@@ -2103,6 +2198,7 @@ extern "C" {
//
GGML_API int ggml_cpu_has_avx (void);
GGML_API int ggml_cpu_has_avx_vnni (void);
GGML_API int ggml_cpu_has_avx2 (void);
GGML_API int ggml_cpu_has_avx512 (void);
GGML_API int ggml_cpu_has_avx512_vbmi(void);
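A quick runtime check for the newly added flag (sketch):

```c
#include <stdio.h>
#include "ggml.h"

int main(void) {
    // new detection function added alongside the existing ggml_cpu_has_* family
    printf("AVX_VNNI = %d\n", ggml_cpu_has_avx_vnni());
    return 0;
}
```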


@@ -39,19 +39,19 @@ https://huggingface.co/ggerganov/whisper.cpp/tree/main
## Available models
| Model | Disk | Mem | SHA |
| --- | --- | --- | --- |
| tiny | 75 MB | ~390 MB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| tiny.en | 75 MB | ~390 MB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
| base | 142 MB | ~500 MB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| base.en | 142 MB | ~500 MB | `137c40403d78fd54d454da0f9bd998f78703390c` |
| small | 466 MB | ~1.0 GB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| small.en | 466 MB | ~1.0 GB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
| medium | 1.5 GB | ~2.6 GB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| medium.en | 1.5 GB | ~2.6 GB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
| large-v1 | 2.9 GB | ~4.7 GB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
| large-v2 | 2.9 GB | ~4.7 GB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
| large | 2.9 GB | ~4.7 GB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
| Model | Disk | SHA |
| --- | --- | --- |
| tiny | 75 MiB | `bd577a113a864445d4c299885e0cb97d4ba92b5f` |
| tiny.en | 75 MiB | `c78c86eb1a8faa21b369bcd33207cc90d64ae9df` |
| base | 142 MiB | `465707469ff3a37a2b9b8d8f89f2f99de7299dac` |
| base.en | 142 MiB | `137c40403d78fd54d454da0f9bd998f78703390c` |
| small | 466 MiB | `55356645c2b361a969dfd0ef2c5a50d530afd8d5` |
| small.en | 466 MiB | `db8a495a91d927739e50b3fc1cc4c6b8f6c2d022` |
| medium | 1.5 GiB | `fd9727b6e1217c2f614f9b698455c4ffd82463b4` |
| medium.en | 1.5 GiB | `8c30f0e44ce9560643ebd10bbe50cd20eafd3723` |
| large-v1 | 2.9 GiB | `b1caaf735c4cc1429223d5a74f0f4d0b9b59a299` |
| large-v2 | 2.9 GiB | `0f4c8e34f21cf1a914c59d8b3ce882345ad349d6` |
| large-v3 | 2.9 GiB | `ad82bf6a9043ceed055076d0fd39f5f186ff8062` |
## Model files for testing purposes
@@ -76,3 +76,27 @@ git clone https://huggingface.co/openai/whisper-medium
# convert the model to ggml
python3 ./whisper.cpp/models/convert-h5-to-ggml.py ./whisper-medium/ ./whisper .
```
## Distilled models
Initial support for https://huggingface.co/distil-whisper is available.
Currently, the chunk-based transcription strategy is not implemented, so transcription quality may be sub-optimal when using the distilled models with `whisper.cpp`.
```bash
# clone OpenAI whisper and whisper.cpp
git clone https://github.com/openai/whisper
git clone https://github.com/ggerganov/whisper.cpp
# get the models
cd whisper.cpp/models
git clone https://huggingface.co/distil-whisper/distil-medium.en
git clone https://huggingface.co/distil-whisper/distil-large-v2
# convert to ggml
python3 ./convert-h5-to-ggml.py ./distil-medium.en/ ../../whisper .
mv ggml-model.bin ggml-medium.en-distil.bin
python3 ./convert-h5-to-ggml.py ./distil-large-v2/ ../../whisper .
mv ggml-model.bin ggml-large-v2-distil.bin
```


@@ -78,14 +78,14 @@ def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
# Ported from models/convert-whisper-to-coreml.py
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
raise ValueError("Invalid model name")
pt_target_path = f"models/hf-{args.model_name}.pt"


@@ -296,13 +296,13 @@ def convert_decoder(hparams, model, quantize=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()


@@ -38,10 +38,10 @@ def convert_encoder(hparams, encoder, mname):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1, large-v2)", required=True)
parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3)", required=True)
args = parser.parse_args()
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v1", "large-v2", "large-v3"]:
raise ValueError("Invalid model name")
whisper = load_model(args.model).cpu()

Some files were not shown because too many files have changed in this diff.