Compare commits

..

167 Commits

Author SHA1 Message Date
2e59dced12 whisper : rename binaries + fix install (#2648)
* whisper : rename binaries + fix install

* cont : try to fix ci

* cont : fix emscripten builds
2024-12-21 09:43:49 +02:00
e4e05981d6 ruby : update gem version to v1.3.1
Some checks failed
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / quantize (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Has been cancelled
2024-12-20 11:53:27 +02:00
3de9deead5 release : v1.7.3
Some checks failed
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / quantize (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Has been cancelled
2024-12-18 18:12:40 +02:00
47f989f9b3 ci : msys enable SDL2 build (#2635)
Some checks failed
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run
Bindings Tests (Ruby) / ubuntu-latest (push) Has been cancelled
2024-12-18 12:52:41 +02:00
acc4e13dee ruby : sync ggml (#2643) 2024-12-18 12:52:16 +02:00
ba6c2a8fd9 android : try to fix build 2024-12-18 12:52:16 +02:00
6576af00d7 files : remove old sources 2024-12-18 12:52:16 +02:00
8ac5db0169 sync : ggml 2024-12-18 12:52:16 +02:00
61edb117a0 talk-llama : sync llama.cpp 2024-12-18 12:52:16 +02:00
eb97b257eb sync : ggml 2024-12-18 12:52:16 +02:00
479499dc0e ggml : update ggml_backend_cpu_device_supports_op (llama/10867)
* ggml : fix cpy op for IQ-quants to use reference impl

ggml-ci

* ggml : disable tests involving i-matrix quantization

* ggml : update ggml_backend_cpu_device_supports_op

ggml-ci
2024-12-18 12:52:16 +02:00
Eve
d420a759c5 vulkan: bugfixes for small subgroup size systems + llvmpipe test (llama/10809)
* ensure mul mat shaders work on systems with subgroup size less than 32

more fixes

add test

* only s_warptile_mmq needs to be run with 32 threads or more
2024-12-18 12:52:16 +02:00
a1ab9b5e91 rwkv6: add wkv6 support for Vulkan backend (llama/10829)
* rwkv_wkv6 vulkan shader

* RWKV_WKV6 Vulkan op tests passed

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* Apply code format changes

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* add [[unroll]] and remove unnecessary conditions

* add uma support

* fix erros in EditorConfig Checker

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Molly Sophia <mollysophia379@gmail.com>
2024-12-18 12:52:16 +02:00
e22d38e4f2 llama : add Qwen2VL support + multimodal RoPE (llama/10361)
* Barebone Qwen2VL LLM convertor

* Add Qwen2VL cli entrypoint

* [WIP] add qwen2vl arch

* Verify m-rope output

* Add vl-rope/2d-rope support for qwen2vl ViT

* update qwen2vl cli tool

* update 5D tensor op workaround

* [WIP] qwen2vl vision model

* make batch and clip utils compatible with qwen2vl

* [WIP] create inference workflow, gguf convert script but fix

* correcting vision-rope behavior, add the missing last layer back to ViT

* add arg parser to qwen2vl_surgery

* replace variable size array with vector

* cuda-gdb cmake preset

* add fp32 mrope, vision rope kernel

* add fp16 support for qwen2vl and m-rope

* add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`

* fix rope op mode switching, out dated func args

* update `llama_hparams`

* update to keep up stream changes

* resolve linter, test errors

* add makefile entry, update speical image padding token

* add mrope unit test, fix few compiler warnings

* rename `mrope` related function, params

* minor updates on debug util, bug fixs

* add `m-rope` testcase to `test-backend-ops`

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix traililng whitespce

* store `llama_hparams.rope_sections` with fixed size array

* update position id tensor size check in GGML_OP_ROPE

* minor updates

* update `ggml_backend_*_supports_op` of unsupported backends

* remote old `rope_section` compare operator

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-12-18 12:52:16 +02:00
856fbaa92f Introducing experimental OpenCL backend with support for Qualcomm Adreno GPUs (llama/10693)
* [cl][adreno] Add Adreno GPU support

Add new OpenCL backend to support Adreno GPUs

---------

Co-authored-by: Skyler Szot <quic_sszot@quicinc.com>
Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com>
Co-authored-by: Alexander Angus <quic_aangus@quicinc.com>
Co-authored-by: Hongqiang Wang <quic_wangh@quicinc.com>
Co-authored-by: Max Krasnyansky <quic_maxk@quicinc.com>

* [cl][ci] Add workflow for CL

* [cl][adreno] Fix memory leak for non SMALL_ALLOC path

* opencl: integrate backend dyn.load interface and fix compiler and format warnings

* opencl: remove small-alloc support and fix build errors for non-opencl platforms

* opencl: fixed merge conflict (MUSA added twice in cmake)

* opencl-ci: use RUNNER_TEMP instead of github.workspace

* opencl: fix embed tool invocation with python3

* opencl: CI workflow fixes

* opencl: Clean up small-alloc in CMake files

* opencl: cleanup ggml-opencl2 header file

* opencl: use ulong for offsets and strides in ADD kernel

* opencl: use cl_ulong for all offsets

* opencl: use cl_ulong for sizes and strides

* opencl: use `GGML_LOG_xxx` instead of `fprintf(stderr, ...)`

* opencl: rename backend `opencl2` -> `opencl`

* opencl: rename kernel files `ggml-opencl2` -> `ggml-opencl`

* opencl: make OpenCL required, remove redundant lib and inc directories

* `ggml-base`, `..` and `.` are added by `ggml_add_backend_library`

* opencl: rename backend - funcs, structs, etc `opencl2` -> `opencl`

* opencl: remove copyright marker since main license already covers

* opencl: replace some more OPENCL2 leftovers

* opencl: remove limits on `tensor_extra`

* opencl: use pools for `tensor_extra`

* opencl: fix compiler warnings with GCC and Clang

Still getting the warning about clCreateCmdQueue being obsolete.
Will fix that separately.

* opencl: fail gracefully if opencl devices are not available

Also for unsupported GPUs.

* opencl: fix MSVC builds (string length error)

* opencl: check for various requirements, allow deprecated API

* opencl: update log message for unsupported GPUs

---------

Co-authored-by: Skyler Szot <quic_sszot@quicinc.com>
Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com>
Co-authored-by: Alexander Angus <quic_aangus@quicinc.com>
Co-authored-by: Hongqiang Wang <quic_wangh@quicinc.com>
Co-authored-by: Max Krasnyansky <quic_maxk@quicinc.com>
2024-12-18 12:52:16 +02:00
2c05efa4b1 Fix crash caused by ggml_backend_load_all when launching on Android Activity (llama/10812)
* Fix crash caused by ggml_backend_load_all when launching on AndroidActivity.

Details:
Calling ggml_backend_load_all during initialization in the AndroidActivity project leads to a crash with the error:
terminating with uncaught exception of type std::__ndk1::__fs::filesystem::filesystem_error: filesystem error: in directory_iterator::directory_iterator(...): Permission denied [./].
This issue occurs because AndroidActivity restricts file access due to sandboxing.

Reproduction:
In the example folder, the LlamaAndroid project can reproduce the crash by calling ggml_backend_load_all first in Java_android_llama_cpp_LLamaAndroid_backend_1init.

* Update ggml/src/ggml-backend-reg.cpp

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2024-12-18 12:52:16 +02:00
Eve
c21fb10b28 vulkan: small mul_mat_vec optimizations (llama/10665)
* double the number of rows per workgroup

* Update ggml-vulkan.cpp

* Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats

* only increase the number of rows for amd and subgroup size 64

* fix missing NUM_ROWS for mul_mat_vec_iq4_nl_f16_f32, untested

* use subgroup min and max to check for gcn (requires https://github.com/ggerganov/llama.cpp/pull/10721)

* manual merge ggml-vulkan.cpp

* set min and max subgroup size in any case

* Also double the number of rows for Intel GPUs
2024-12-18 12:52:16 +02:00
26c9fd0cdc SYCL: Reduce most of the compiler warnings (llama/10748)
* Try to reduce some unused and typecast warnings

* Reduce compiler warnings step 2

* add a newline at the end of the file

* Initialize nreduce as size_t

* [SYCL] Remove pragma directives from mmq.cpp

* SYCL: mmq add condition to prevent blocks_per_tile_x_row variable from becoming 0

* SYCL softmax: Initialize nreduce as size_t

* ggml-sycl.cpp: fix some trailing whitespaces

* SYCL: remove the unused variables instead of commenting it out

* SYCL poo2d kernel: set NAN for invalid pooling op

* SYCL gemm.hpp: remove pragma directives

* SYCL gemm.hpp: use const cast to properly support dnnl::memory

* SYCL: wkv6 remove a comment

* SYCL: clean comments step 2

* SYCL: clean comments and variables step 3

* SYCL: Use GGML_UNUSED for unused variables

* SYCL: remove extra empty lines and a comment

* Remove TODO

* cleanup spaces

* add a stdout for unsupported op

* use sycl printf over fprintf

* remove prints for CI

* SYCL ggml-sycl: pool2D use sycl::nan and remove if-else block

---------

Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
2024-12-18 12:52:16 +02:00
e6eed605cf ggml : Fix compilation issues on ARM platform when building without fp16 (llama/10811) 2024-12-18 12:52:16 +02:00
abe3102cb7 CUDA: faster non-contiguous concat (llama/10760)
* faster uncontiguous concat

* Use a lambda to avoid code duplication

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Update ggml/src/ggml-cuda/concat.cu

* add constexpr  and static assert

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2024-12-18 12:52:16 +02:00
1193e494a9 remove CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS (llama/10797)
other windows build fixes
2024-12-18 12:52:16 +02:00
e5e951672e Vulkan: Use improved q4_k and q5_k dequant code in dequant shaders (llama/10798) 2024-12-18 12:52:16 +02:00
0e24559ad9 Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats (llama/10721)
* Vulkan: Add VK_EXT_subgroup_size_control support to ensure full subgroups for coopmats

* Fix subgroup size control extension support check

Add accf32 and accf16 checks for coopmats

* Also disable coopmats on amdvlk
2024-12-18 12:52:16 +02:00
527ac800cf ggml: load all backends from a user-provided search path (llama/10699)
* feat: load all backends from a user-provided search path

* fix: Windows search path

* refactor: rename `ggml_backend_load_all_in_search_path` to `ggml_backend_load_all_from_path`

* refactor: rename `search_path` to `dir_path`

* fix: change `NULL` to `nullptr`

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* fix: change `NULL` to `nullptr`

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2024-12-18 12:52:16 +02:00
479bd77169 vulkan: request round-to-even for fp16 in im2col/rope_head (llama/10767)
Vulkan doesn't mandate a specific rounding mode, but the shader_float_controls
feature allows rounding mode to be requested if the implementation supports it.
2024-12-18 12:52:16 +02:00
Eve
d8bf63a41b vulkan: dynamic subgroup size for the remaining k quants (llama/10745)
* q5_k

q4_k

q3_k

q2_k

q6_k multi row example

* revert as multi row isnt faster for k quants
2024-12-18 12:52:16 +02:00
b82c8d76dc CUDA: rename macros to avoid conflicts with WinAPI (llama/10736)
* Renames NVIDIA GPU-architecture flags to avoid name clashes with WinAPI. (e.g. CC_PASCAL, GPU architecture or WinAPI pascal compiler flag?)

* Reverts erroneous rename in SYCL-code.

* Renames GGML_CUDA_MIN_CC_DP4A to GGML_CUDA_CC_DP4A.

* Renames the rest of the compute capability macros for consistency.
2024-12-18 12:52:16 +02:00
86346f811e vulkan: disable spirv-opt for coopmat shaders (llama/10763)
There are some bugs in the 1.3.296 SDK, so disable this. It isn't strictly
necessary anyway.

Add missing dependency on vulkan-shaders-gen, so shaders get recompiled when it
changes.

Fix coopmat support reporting when glslc doesn't support NV_coopmat2.
2024-12-18 12:52:16 +02:00
c635f40a34 ggml : remove return from ggml_gallocr_allocate_node (ggml/1048)
This commit removes the return statement from ggml_gallocr_allocate_node
function.

The motivation behind this change is to make the code more readable and
consistent.
2024-12-18 12:52:16 +02:00
e0be0de1ee ggml : add check for grad_accs (ggml/1046)
* ggml : add check for grad_accs

This commit adds a check for grad_accs in ggml_graph_get_grad and
ggml_graph_get_grad_acc functions. This is necessary to avoid segfaults
when grad_accs is not initialized.

The motivation for this change is that I find it nice to be able to
print out a computation graph using ggml_graph_print but this function
segfaults when grad_accs is not initialized:
```console
(gdb) p g1
$2 = (ggml_cgraph *) 0x7ffff66004b0
(gdb) p *g1
$3 = {size = 2048, n_nodes = 1, n_leafs = 2, nodes = 0x7ffff6600500,
grads = 0x0, grad_accs = 0x0, leafs = 0x7ffff6604500,
visited_hash_set = {size = 4099, used = 0x7ffff6610518,
keys = 0x7ffff6608500}, order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT}
(gdb) p ggml_graph_print(g1)
=== GRAPH ===
n_nodes = 1

Program received signal SIGSEGV, Segmentation fault.
0x0000555555579775 in ggml_graph_get_grad
(cgraph=0x7ffff66004b0,node=0x7ffff6600340)
    at /ggml/ggml/src/ggml.c:5990
5990  return igrad != GGML_HASHSET_FULL &&
          ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ?
          cgraph->grads[igrad] : NULL;
```

* squash! ggml : add check for grad_accs

Fix the check in ggml_graph_get_grad. The check was incorrectly using
cgraph->grad_accs instead of cgraph->grads.
2024-12-18 12:52:16 +02:00
60dc6d003f common : remove old types
ggml-ci
2024-12-18 12:52:16 +02:00
eb27e0d834 CUDA: fix shared memory access condition for mmv (llama/10740) 2024-12-18 12:52:16 +02:00
a682fdce0c vulkan: fix compile warnings (llama/10731) 2024-12-18 12:52:16 +02:00
9ffbd3d969 Vulkan: fix NaN in tanh.comp with AMD proprietary driver on Windows (llama/10723)
* Vulkan: fix NaN in tanh.comp

* Faster NaN-free tanh
2024-12-18 12:52:16 +02:00
6585a890b4 vulkan: compile a test shader in cmake to check for coopmat2 support (llama/10713) 2024-12-18 12:52:16 +02:00
d0a050b51f ggml : disable iq4_nl interleave size 8 (llama/10709)
ggml-ci
2024-12-18 12:52:16 +02:00
e990d1b791 ggml : refactor online repacking (llama/10446)
* rename ggml-cpu-aarch64.c to .cpp

* reformat extra cpu backend.

- clean Q4_0_N_M and IQ4_0_N_M
  - remove from "file" tensor type
  - allow only with dynamic repack

- extract cpu extra bufts and convert to C++
  - hbm
  - "aarch64"

- more generic use of extra buffer
  - generalise extra_supports_op
  - new API for "cpu-accel":
     - amx
     - aarch64

* clang-format

* Clean Q4_0_N_M ref

Enable restrict on C++

* add op GGML_OP_MUL_MAT_ID for Q4_0_N_M with runtime repack

* added/corrected control on tensor size for Q4 repacking.

* Update ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add debug logs on repacks.

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-12-18 12:52:16 +02:00
4a6d52efe6 Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (llama/10597)
* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

* Improve performance with better q4_k and q5_k dequant and store unrolling

* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection

* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device

* Vulkan: Implement accumulator switch for specific mul mat mat shaders

* Vulkan: Unroll more loops for more mul mat mat performance

* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic

* Disable coopmat support on AMD proprietary driver

* Remove redundant checks

* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support

* Fix rebase typo

* Fix coopmat2 MUL_MAT_ID pipeline selection
2024-12-18 12:52:16 +02:00
8b841d430a metal : Extend how Llama.cpp locates metal resources (llama/10676)
* metal : Extend how Llama.cpp locates metal resources (llama/10675)

  * It searches the resource file in the directory where the current
    binary is located as well.
  * Resolves symbolic links.

Rationale:

When we plug this dependency into a Bazel build and run it in the
context of Bazel (e.g. testing):

  * the execution directory is often very different from where the files
    are located and no direct control over this (Bazel sandboxing),
  * the Bazel sandbox often use symbolic links to make files available.

With this patch, we can have the resource file added to the target,
can build and run tests in the context of Bazel.

* Update ggml/src/ggml-metal/ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-metal/ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-12-18 12:52:16 +02:00
b74b68212a vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (llama/10206) 2024-12-18 12:52:16 +02:00
3a27b2b91b ruby : Add no_speech_thold (#2641)
* Remove Whisper::Model.[]

* Fix Whisper::Model::URI#request

* Make Whisper::Context#initialize accept pre-converted model name

* Use downloading pre-converted model feature for testing

* Update README

* Remove unnecessary task

* Move whisper/model.rb -> whisper/model/uri.rb

* Update document comment of Whisper::Context#initialize

* Don't show download progress when not tty

* Pass String to raise

* Use cache model file if download fails

* Add test for auto download

* Specify required Ruby version

* Fix a typo

* Remove unnecessary flags

* Initialize Whisper::Params#diarize explicitely

* Remove redundant code from README for simplicity

* Add Whisper::Params#no_speech_thold attribute

* Add test for Whisper::Params#no_speech_thold
2024-12-18 11:00:50 +02:00
d34445e960 stream : improve consistency in README (#2642) 2024-12-18 08:43:48 +02:00
f897eb7670 whisper : support no_speech_thold (#2625)
Some checks are pending
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run
* Implement no_speech_thold

no_speech_thold functionality is on par with OpenAI's whisper

* Addressed review comments
2024-12-17 19:15:47 +02:00
2f2841bfce whisper : add single-timestamp logic (#2629)
* Fix hallucinations during silence

When the predicted tokens end with a single timestamp the the entire 30 segment should be considered as done, to avoid hallucinations for the remaining part of segment.
This behaviour is on par with openai's whisper. Refer to logic related to `single_timestamp_ending` in https://github.com/openai/whisper/blob/main/whisper/transcribe.py

* Accept review comments related to formatting.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-12-17 19:07:08 +02:00
09a1b61218 readme : fix typo (#2637) 2024-12-17 19:05:35 +02:00
94e7da1ff2 cmake : fix "amd64" processor string (#2638) 2024-12-17 18:34:32 +02:00
c4aed6831e vulkan : fix soft_max.comp division by zero (#2633)
Some checks failed
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / quantize (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Has been cancelled
This change prevents a division by zero error when p.KY is 0.
2024-12-16 12:34:38 +02:00
199579652e common : add cstdio header 2024-12-16 08:57:04 +02:00
d17e7139d8 stream : update build instructions
Some checks are pending
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run
2024-12-15 21:55:36 +02:00
6a52eaea74 android : fix build and ci (#2624)
Some checks failed
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / quantize (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Has been cancelled
* Adding missing CMakeLists.txt include for ggm-cpu needed by whisper.android

* attempt to re-enable CI for JNI android

---------

Co-authored-by: Your Name <you@example.com>
2024-12-14 17:25:53 +02:00
6aa1d7b892 models : fix typo in download-ggml-model.sh (#2623)
Some checks failed
CI / ubuntu-latest-gcc (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / quantize (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Has been cancelled
Introduced in #2589
2024-12-12 18:02:00 +02:00
262e865a70 ruby : Sync whisper.cpp and model download feature (#2617)
* Use C++17

* Add test for Pathname of model

* Make Whisper::Context#initialize accept Pathname

* Add shorthand for pre-converted models

* Update documents

* Add headings to API section in README [skip ci]

* Remove unused function

* Don't care about no longer included file

* Cosmetic fix

* Use conditional get when get model files
2024-12-09 13:17:50 +02:00
ed733e85a1 scripts : update to new build system 2024-12-09 11:30:16 +02:00
5980b1ae77 devops : add cmake 2024-12-08 23:09:26 +02:00
0415a66044 devops : update make commands 2024-12-08 23:07:29 +02:00
7d134e3737 ggml : remove old files (skip) (#0) 2024-12-08 23:04:26 +02:00
9df53b357e ggml : sync remnants (skip) (#0) 2024-12-08 22:48:25 +02:00
b2115b4d9b scripts : remove amx from sync 2024-12-08 22:48:14 +02:00
0164427dd5 ci : disable freeBSD builds [no ci] 2024-12-08 20:14:35 +02:00
627b11c78a readme : update build instructions 2024-12-08 20:14:35 +02:00
472464453d ci : disable CUDA and Android builds 2024-12-08 20:14:35 +02:00
11dddfbc9e ci : disable Obj-C build + fixes 2024-12-08 20:14:35 +02:00
384e214cc7 make : shim cmake 2024-12-08 20:14:35 +02:00
f2c680f893 talk-llama : sync llama.cpp 2024-12-08 20:14:35 +02:00
fbe66da0e5 sync : ggml 2024-12-08 20:14:35 +02:00
a815940e0e ggml : add predefined list of CPU backend variants to build (llama/10626)
* ggml : add predefined list of CPU backend variants to build

* update CPU dockerfiles
2024-12-08 20:14:35 +02:00
904e307bce ggml-cpu : fix HWCAP2_I8MM value (llama/10646) 2024-12-08 20:14:35 +02:00
491ec076b4 vulkan: Implement "fast divide" (mul+shift) for unary ops like copy (llama/10642) 2024-12-08 20:14:35 +02:00
966433fdf2 SYCL : Move to compile time oneMKL interface backend selection for NVIDIA backend (llama/10584)
* [SYCL] Move to Compile Time backend selection on oneMKL Interface for NVIDIA backend

Move to compile time selection to backend to avoid latency at run time.
Add it to all mkl gemm calls and only for NVIDIA backend.

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>

* Formatting

* Address PR comments to increase readibility

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
2024-12-08 20:14:35 +02:00
6f1ba9d82d Avoid using __fp16 on ARM with old nvcc (llama/10616) 2024-12-08 20:14:35 +02:00
015ecd0001 vulkan: optimize and reenable split_k (llama/10637)
Use vector loads when possible in mul_mat_split_k_reduce. Use split_k
when there aren't enough workgroups to fill the shaders.
2024-12-08 20:14:35 +02:00
PAB
b7c64a4352 ggml: add GGML_SET Metal kernel + i32 CPU kernel (ggml/1037)
* implemented cpu kernel

* add i32 test cases in test-backend-ops

* typedef `ggml_metal_kargs_set`

* implemented `kernel_set`

* memcpy
2024-12-08 20:14:35 +02:00
PAB
7895d39508 ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)
* ggml_pad_reflect_1d defined in header

* implemented on CPU

* called the forward pass

* impl Metal kernel

* added Metal kernel

* added OP_PAD_REFLECT_1D in test-backend-ops.cpp

* add test-pad-reflect-1d test case

* test case support multiple backend
2024-12-08 20:14:35 +02:00
22616f00f9 files : remove make artifacts 2024-12-08 20:14:35 +02:00
02c6fcbc2c common : fix compile warning
ggml-ci
2024-12-08 20:14:35 +02:00
3daeacad24 ggml : move AMX to the CPU backend (llama/10570)
ggml : automatic selection of best CPU backend (llama/10606)
2024-12-08 20:14:35 +02:00
4d73962da4 metal : small-batch mat-mul kernels (llama/10581)
* metal : small-batch mat-mul kernels

ggml-ci

* metal : add rest of types

ggml-ci

* metal : final adjustments

ggml-ci

* metal : add comments

ggml-ci
2024-12-08 20:14:35 +02:00
068812650e SYCL: Fix and switch to GGML_LOG system instead of fprintf (llama/10579)
* Switched to GGML_LOG

* Fix missing semicolon
2024-12-08 20:14:35 +02:00
4b7e059e15 ggml-cpu: replace AArch64 NEON assembly with intrinsics in ggml_gemv_q4_0_4x4_q8_0() (llama/10567)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2024-12-08 20:14:35 +02:00
Eve
30e35d7271 vulkan: Dynamic subgroup size support for Q6_K mat_vec (llama/10536)
* subgroup 64 version with subgroup add. 15% faster

scalable version

tested for subgroup sizes 16-128

* check for subgroup multiple of 16 and greater than 16

* subgroup sizes are always a power of 2 (https://github.com/KhronosGroup/GLSL/issues/45)

* force 16 sequential threads per block

* make 16 subgroup size a constant
2024-12-08 20:14:35 +02:00
3623bd58f2 ggml : fix I8MM Q4_1 scaling factor conversion (llama/10562)
ggml-ci
2024-12-08 20:14:35 +02:00
cb847c20a7 ggml-cpu: fix typo in gemv/gemm iq4_nl_4_4 (llama/10580) 2024-12-08 20:14:35 +02:00
964b154a2a sycl : offload of get_rows set to 0 (llama/10432) 2024-12-08 20:14:35 +02:00
d7c2a04bce sycl : Reroute permuted mul_mats through oneMKL (llama/10408)
This PR fixes the failing MUL_MAT tests for the sycl backend.
2024-12-08 20:14:35 +02:00
2bb4ca9cba CANN: RoPE operator optimization (llama/10563)
* [cann] RoPE operator optimization

* [CANN]Code Formatting

---------

Co-authored-by: noemotiovon <noemotiovon@gmail.com>
2024-12-08 20:14:35 +02:00
a753a82462 vulkan: get the first command buffer submitted sooner (llama/10499)
This is an incremental improvement over #9118 to get work to the GPU a bit
sooner. The first part is to start with a smaller number of nodes before
the first submit, and ramp it up to the current 100 nodes/submit. The
second part is to reduce the dryrun overhead for all the nodes that just
need to request descriptor space.

With these changes I get around 1-2% speedup on RTX 4070 combined with my
old Haswell-era CPU.
2024-12-08 20:14:35 +02:00
276b08d8f0 ggml : remove redundant copyright notice + update authors 2024-12-08 20:14:35 +02:00
4ca1e72fe0 ggml : fix row condition for i8mm kernels (llama/10561)
ggml-ci
2024-12-08 20:14:35 +02:00
16a66f103f cmake : fix ARM feature detection (llama/10543)
ggml-ci
2024-12-08 20:14:35 +02:00
330273901f ggml-cpu: support IQ4_NL_4_4 by runtime repack (llama/10541)
* ggml-cpu: support IQ4_NL_4_4 by runtime repack

* ggml-cpu: add __ARM_FEATURE_DOTPROD guard
2024-12-08 20:14:35 +02:00
42099a9342 kompute : improve backend to pass test_backend_ops (llama/10542)
* kompute: op_unary: reject unsupported parameters

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: softmax: implement ALiBi support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: rope: implement neox and phi3 support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_q4_k permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_[q4_0|q4_1|q8_0] permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_f16 permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_q6_k permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

---------

Signed-off-by: Sergio Lopez <slp@redhat.com>
2024-12-08 20:14:35 +02:00
90dd5fca9c CANN: Fix SOC_TYPE compile bug (llama/10519)
* CANN: Fix the bug build fail on Ascend310P under two cases:
1) Manual specify SOC_TYPE
2) Under some unusual compile environment

* Update the cann backend News content: Support F16 and F32 data type model for Ascend 310P NPU.

* fix CANN  compile fail bug: the assert in ascend kernel function doesn't supportted on some CANN version
2024-12-08 20:14:35 +02:00
2490f2a7f8 CANN: ROPE operator optimization (llama/10540)
* [cann] ROPE operator optimization

Co-authored-by: noemotiovon <noemotiovon@gmail.com>
2024-12-08 20:14:35 +02:00
230e985633 Add some minimal optimizations for CDNA (llama/10498)
* Add some minimal optimizations for CDNA

* ggml_cuda: set launch bounds also for GCN as it helps there too
2024-12-08 20:14:35 +02:00
ae24083f23 metal : fix group_norm support condition (llama/0) 2024-12-08 20:14:35 +02:00
6463e36369 vulkan: define all quant data structures in types.comp (llama/10440) 2024-12-08 20:14:35 +02:00
b3301f7d82 vulkan: Handle GPUs with less shared memory (llama/10468)
There have been reports of failure to compile on systems with <= 32KB
of shared memory (e.g. #10037). This change makes the large tile size
fall back to a smaller size if necessary, and makes mul_mat_id fall
back to CPU if there's only 16KB of shared memory.
2024-12-08 20:14:35 +02:00
ab5d4d93ec vulkan: further optimize q5_k mul_mat_vec (llama/10479) 2024-12-08 20:14:35 +02:00
2d6e9dd723 vulkan: skip integer div/mod in get_offsets for batch_idx==0 (llama/10506) 2024-12-08 20:14:35 +02:00
2f16e51553 vulkan: optimize Q2_K and Q3_K mul_mat_vec (llama/10459) 2024-12-08 20:14:35 +02:00
0f0994902f mtgpu: Add MUSA_DOCKER_ARCH in Dockerfiles && update cmake and make (llama/10516)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
2024-12-08 20:14:35 +02:00
5e1fcc1780 vulkan: fix group_norm (llama/10496)
Fix bad calculation of the end of the range. Add a backend test that
covers the bad case (taken from stable diffusion).

Fixes https://github.com/leejet/stable-diffusion.cpp/issues/439.
2024-12-08 20:14:35 +02:00
48f421de23 cmake : enable warnings in llama (llama/10474)
* cmake : enable warnings in llama

ggml-ci

* cmake : add llama_get_flags and respect LLAMA_FATAL_WARNINGS

* cmake : get_flags -> ggml_get_flags

* speculative-simple : fix warnings

* cmake : reuse ggml_get_flags

ggml-ci

* speculative-simple : fix compile warning

ggml-ci
2024-12-08 20:14:35 +02:00
e7afb2b991 ggml-cpu: cmake add arm64 cpu feature check for macos (llama/10487)
* ggml-cpu: cmake add arm64 cpu feature check for macos

* use vmmlaq_s32 for compile option i8mm check
2024-12-08 20:14:35 +02:00
9a5ef7b169 CANN: Improve the Inferencing Performance for Ascend NPU Device (llama/10454)
* improve inferencing performance for ascend npu.

Co-authored-by: Frank Mai <thxCode@thxcode0824@gmail.com>

* some modification after review

* some modifications after review

* restore some modifications

* restore some modifications

---------

Co-authored-by: shanshan shen <shanshanshen333@gmail.com>
Co-authored-by: Frank Mai <thxCode@thxcode0824@gmail.com>
2024-12-08 20:14:35 +02:00
453cc0fcf1 CANN: RoPE and CANCAT operator optimization (llama/10488)
Co-authored-by: noemotiovon <noemotiovon@gmail.com>
2024-12-08 20:14:35 +02:00
78dfec6bc5 vulkan: Fix a vulkan-shaders-gen arugment parsing error (llama/10484)
The vulkan-shaders-gen was not parsing the --no-clean argument correctly.
Because the previous code was parsing the arguments which have a value only
and the --no-clean argument does not have a value, it was not being parsed
correctly. This commit can now correctly parse arguments that don't have values.
2024-12-08 20:14:35 +02:00
f6d518fc4c metal : enable mat-vec kernels for bs <= 4 (llama/10491) 2024-12-08 20:14:35 +02:00
ac33379a35 llama : accept a list of devices to use to offload a model (llama/10497)
* llama : accept a list of devices to use to offload a model

* accept `--dev none` to completely disable offloading

* fix dev list with dl backends

* rename env parameter to LLAMA_ARG_DEVICE for consistency
2024-12-08 20:14:35 +02:00
77e3e4a090 ggml : add support for dynamic loading of backends (llama/10469)
* ggml : add support for dynamic loading of backends

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-12-08 20:14:35 +02:00
b840bb09be metal : minor code formatting 2024-12-08 20:14:35 +02:00
8b1c1c30a7 ggml : do not use ARM features not included in the build (llama/10457) 2024-12-08 20:14:35 +02:00
4b81335f75 CANN: Support Ascend310P to accelerate F32 and F16 Model (llama/10216)
* CANN Support Ascend310P to accelerate F32 and F16 Model

* Add compile option soc type macro ASCEND_310P to ggml-cann lib

* Remove unused code

* Remove the ascend soc_type hard code compile option in CMakelist.txt
2024-12-08 20:14:35 +02:00
2a4b5c9d7e cuda : optimize argmax (llama/10441)
* cuda : optimize argmax

* remove unused parameter

ggml-ci

* fixup : use full warps

ggml-ci

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix ub

* ggml : check ne00 <= INT32_MAX in argmax and argsort

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2024-12-08 20:14:35 +02:00
04662748aa vulkan: predicate max operation in soft_max shaders/soft_max (llama/10437)
Fixes #10434
2024-12-08 20:14:35 +02:00
a117279e13 vulkan: copy iq4_nl LUT into shared memory (llama/10409) 2024-12-08 20:14:35 +02:00
bbb292ed38 vulkan: further optimize mul_mat_vec using larger loads (llama/10387)
* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.

Add some early returns for nonexistent rows in mul_mat_vec shaders. These
can only be hit when dispatching a 2D grid of workgroups. Fix the logic
for the 2D grid of workgroups to round up.

Enable the pipeline robustness extension if it's available, and use it to
disable robustness for these pipelines. The instructions to do the bounds
checking contend for the same ALU resources as the bit twiddling dequant
instructions.

* vulkan: Add GLSL structure aliases for quant types to allow larger loads

In Vulkan it's not possible to cast pointer types, so instead you have to
declare an aliased binding for the memory with a different type. This
commit adds aliases for the quant formats using 16b ints, and in a few
places where the struct size is a multiple of 4 also using 32b ints.
Currently only q4_k's aliases are used, but others will be used in
subsequent commits.

* vulkan: use larger loads in q5_k and q6_k shaders.

Similar to the optimization I did in q4_k recently, this vectorizes some loads
and reduces the number of bit twiddling instructions.

* vulkan: use larger K step per iteration in mul_mat_vec.

Add vec4 dequantization functions, and use them to do K=8 per iteration in
mul_mat_vec. This uses 16b loads for the quant values and 128b loads for B
which helps reduce the load on the memory system.

The K_PER_ITER==2 logic is still there, just for F16/F32, and really only
because they support unaligned sizes.

Tweak the num_iters/unrolling logic to be simpler and catch a couple missed
unrolling opportunities.
2024-12-08 20:14:35 +02:00
95e8901e71 add cmake rvv support (llama/10411) 2024-12-08 20:14:35 +02:00
4af9626702 CUDA: remove unnecessary warp reduce in FA (ggml/1032)
* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit

* same problem in vec32

---------

Co-authored-by: ZhaoXiaoYu <zhao.xiaoyu@zte.com.cn>
2024-12-08 20:14:35 +02:00
PAB
c52d1035de feat: add GGML_UNARY_OP_ARGMAX Metal kernel (ggml/1019)
* implemented argmax kernel

* tpig -> tgpig

* change to strides

* contiguous assertions

* kernel working and tested

* argmax simd parallel implementation

* added 2 new tests for argmax in test-backend-ops

* cosmit

* added 3 tests cases for perf eval

* add test_argmax in make_test_cases_perf

* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <slarengh@gmail.com>

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2024-12-08 20:14:35 +02:00
PAB
5773a14980 metal : add GGML_OP_CONV_TRANSPOSE_1D kernels (ggml/1026)
* wip

* wip implementation f32

* kernel conv transpose 1d f32 working

* initial commit
2024-12-08 20:14:35 +02:00
6939147c47 Do not include arm_neon.h when compiling CUDA code (ggml/1028) 2024-12-08 20:14:35 +02:00
98f9916c9f ggml-opt: fix data corruption (ggml/1022) 2024-12-08 20:14:35 +02:00
021eef1000 ruby : Add low-level methods to transcribe (#2585)
* Add tests for Whisper::Context#full

* Add Whisper::Context#full

* Add tests for Whisper::Error

* Add document of Whisper::Context#full [skip ci]

* Add additional signature for Whisper::Context#full

* Add description to Whisper::Context#full

* Add test for Whisper::Context#full_parallel

* Add Whisper::Context#full_parallel

* Hide Whisper's instance methods from Ruby code

* Add class to test MemoryView

* Build test class before running test

* Add test for MemoryView

* Make Whisper::Context#full and #full_parallel accept MemoryView

* Use Ruby 3.1 on CI

* Add comment on samples data type

* Update README

* Update README

* Remove unused code
2024-11-28 10:33:07 +02:00
a9d06ce151 models : add q8_0 models to download-ggml-model.sh (#2589) 2024-11-28 10:31:54 +02:00
8c6a9b8bb6 ruby : Follow source tree change (#2580)
* Follow whisper.cpp source tree change

* Update whispercpp.gemspec

* Follow whisper.cpp log level change

* Fix paths in GitHub workflow for Ruby bindings

* Use GitHub workflow setting for dependency definition

* Use ternary operator
2024-11-21 17:04:29 +02:00
37c88027e1 whisper : use backend registry (#0) 2024-11-20 21:00:08 +02:00
9db070a3c5 ggml/sched : do not skip views in pre-assignments 2024-11-20 21:00:08 +02:00
7fd8d9c220 whisper : adapt to new ggml (wip) 2024-11-20 21:00:08 +02:00
06e059b8f8 talk-llama : sync llama.cpp 2024-11-20 21:00:08 +02:00
c9f49d5f9d sync : ggml 2024-11-20 21:00:08 +02:00
f4c1d7df39 ggml : sync resolve (skip) (#0) 2024-11-20 21:00:08 +02:00
339b8e559c Add required ggml-base and backend libs to cmake pkg (llama/10407) 2024-11-20 21:00:08 +02:00
5f6d6919b4 cuda : fix CUDA_FLAGS not being applied (llama/10403) 2024-11-20 21:00:08 +02:00
8ee767732f sycl : Add option to set the SYCL architecture for all targets (llama/10266)
* Add option to set the SYCL architecture for all targets
* Convert GGML_SYCL_HIP_TARGET to the more generic GGML_SYCL_ARCH option
* Document that setting GGML_SYCL_ARCH can improve the performance
2024-11-20 21:00:08 +02:00
45f1f9144f vulkan: Optimize soft_max (llama/10301)
* vulkan: Optimize soft_max

Large soft_max could already saturate memory, but small/medium sizes were
pretty slow. The bulk of the gains for them comes from using a smaller
workgroup size, and making the workgroup size match the subgroup size also
makes the barriers much cheaper.

Cache some values in locals to avoid refetching/recomputing. And stamp
out a few "template instantiations" so smaller cases will fully unroll.

Add a missing early return for OOB rows. This happens when there are more
than 512 rows and the dispatch is 512 x H.

* vulkan: Further soft_max optimizations

Restore the workgroup size of 512 case, use it for >1024.

Use unrollable loops for more iteration counts.
2024-11-20 21:00:08 +02:00
53589c8f12 sycl: Revert MUL_MAT_OP support changes (llama/10385) 2024-11-20 21:00:08 +02:00
7ac2f17fac cuda : only use native when supported by cmake (llama/10389) 2024-11-20 21:00:08 +02:00
48862c7b27 vulkan: remove use of null initializer (llama/10372)
Seems like this isn't working for vulkan-over-metal when the array is sized
by a spec constant. Maybe a spirv-cross limitation?
2024-11-20 21:00:08 +02:00
44f7d9f4e3 metal : fox offset integer overflows in im2col (ggml/1015)
-- While running StableDiffusion.cpp locally with Metal some offsets overflow and results in incorrect calculations
2024-11-20 21:00:08 +02:00
fd12302587 Vulkan: Fix device info output format specifiers (llama/10366)
* Vulkan: Fix device info output format specifiers

* Vulkan: Use zu printf specifier for size_t instead of ld
2024-11-20 21:00:08 +02:00
PAB
f80bef4630 metal : add GGML_UNARY_OP_ELU kernel (ggml/1018) 2024-11-20 21:00:08 +02:00
161b443514 CUDA: fix MMV kernel being used for FP16 src1 (llama/10357) 2024-11-20 21:00:08 +02:00
ef7fbe1c66 CMake: fix typo in comment [no ci] (llama/10360) 2024-11-20 21:00:08 +02:00
0879d3599e llama : only use default buffer types for the KV cache (llama/10358) 2024-11-20 21:00:08 +02:00
2a444dc5bd metal : refactor kernel args into structs (llama/10238)
* metal : add kernel arg structs (wip)

* metal : fattn args

ggml-ci

* metal : cont + avoid potential int overflow [no ci]

* metal : mul mat struct (wip)

* cont : mul mat vec

* cont : pass by reference

* cont : args is first argument

* cont : use char ptr

* cont : shmem style

* cont : thread counters style

* cont : mul mm id

ggml-ci

* cont : int safety + register optimizations

ggml-ci

* metal : GGML_OP_CONCAT

ggml-ci

* metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV

* metal : GGML_OP_REPEAT

* metal : GGML_OP_CPY

* metal : GGML_OP_RMS_NORM

* metal : GGML_OP_NORM

* metal : add TODOs for rest of ops

* ggml : add ggml-metal-impl.h

ggml-ci
2024-11-20 21:00:08 +02:00
45cf1634dc ggml : fix undefined reference to 'getcpu' (llama/10354)
https://github.com/ggerganov/llama.cpp/issues/10352
2024-11-20 21:00:08 +02:00
dcb2922d1d CUDA: remove DMMV, consolidate F16 mult mat vec (llama/10318) 2024-11-20 21:00:08 +02:00
3c5c751174 CMake: default to -arch=native for CUDA build (llama/10320) 2024-11-20 21:00:08 +02:00
24ad19d0e9 ggml : fix possible buffer use after free in sched reserve (llama/9930) 2024-11-20 21:00:08 +02:00
bd574b05af ggml : inttypes.h -> cinttypes (llama/0)
ggml-ci
2024-11-20 21:00:08 +02:00
7e0eafcb1e ggml : adapt AMX to tensor->grad removal (llama/0)
ggml-ci
2024-11-20 21:00:08 +02:00
75670ae673 ggml : fix compile warnings (llama/0)
ggml-ci
2024-11-20 21:00:08 +02:00
d4fcdf602b llamafile : fix include path (llama/0)
ggml-ci
2024-11-20 21:00:08 +02:00
1bebb1a116 vulkan: Optimize some mat-vec mul quant shaders (llama/10296)
Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses
the B loads across the rows and also reuses some addressing calculations.
This required manually partially unrolling the loop, since the compiler
is less willing to unroll outer loops.

Add bounds-checking on the last iteration of the loop. I think this was at
least partly broken before.

Optimize the Q4_K shader to vectorize most loads and reduce the number of
bit twiddling instructions.
2024-11-20 21:00:08 +02:00
ee437cde59 ggml : optimize Q4_0 into Q4_0_X_Y repack (llama/10324) 2024-11-20 21:00:08 +02:00
c1506d38cf Make updates to fix issues with clang-cl builds while using AVX512 flags (llama/10314) 2024-11-20 21:00:08 +02:00
c9541741e6 ggml: new optimization interface (ggml/988)
* ggml: new optimization interface

remove test2.c, test3.c

store adamw params in tensor

move grads from tensor to graph

* avoid segfault upon API misuse

* add ggml-opt.h to public headers

* remove dependence of ggml-opt.cpp on ggml-cpu.h
2024-11-20 21:00:08 +02:00
6a55015dc4 ggml : remove duplicated sources from the last sync (ggml/1017)
* ggml : remove duplicated sources from the last sync

ggml-ci

* cont : remove FindSIMD.cmake [no ci]
2024-11-20 21:00:08 +02:00
7e86030d4d ggml : fix some build issues 2024-11-20 21:00:08 +02:00
401fbea326 sync : leftovers (ggml/0)
ggml-ci
2024-11-20 21:00:08 +02:00
44d1cbdfe9 cmake : restore CMakeLists.txt (llama/10256)
ggml-ci
2024-11-20 21:00:08 +02:00
Eve
3216efef2e AVX BF16 and single scale quant optimizations (llama/10212)
* use 128 bit loads (i've tried 256->128 to death and its slower)

* double accumulator

* avx bf16 vec dot

* +3% q4_0 inference

* +7% tg +5% pp compared to master

* slower f16c version, kep for reference

* 256b version, also slow. i tried :)

* revert f16

* faster with madd

* split to functions

* Q8_0 and IQ4_NL, 5-7% faster

* fix potential overflow (performance reduced)

* 16 bit add for q4_0 only

* merge
2024-11-20 21:00:08 +02:00
2c0484ebf7 sycl: Use syclcompat::dp4a (llama/10267)
* sycl: Use syclcompat::dp4a

* Using the syclcompat version allow the compiler to optimize the
  operation with native function

* Update news section

* Update CI Windows oneAPI version to 2025.0

* Reword doc

* Call syclcompat::dp4a inside dpct::dp4a

This reverts commit 90cb61d692d61360b46954a1c7f780bd2e569b73.
2024-11-20 21:00:08 +02:00
3298916e5e backend cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels (llama/9921)
* backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2024-11-20 21:00:08 +02:00
746bf2596f ggml : build backends as libraries (llama/10256)
* ggml : build backends as libraries

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: R0CKSTAR <xiaodong.ye@mthreads.com>
2024-11-20 21:00:08 +02:00
5f7e094ccb scripts : update sync 2024-11-20 21:00:08 +02:00
356 changed files with 36703 additions and 30446 deletions

View File

@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build
ARG CUDA_DOCKER_ARCH=all
RUN apt-get update && \
apt-get install -y build-essential git cmake libsdl2-dev
apt-get install -y build-essential git cmake libsdl2-dev wget
WORKDIR /app
@ -23,6 +23,6 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
# Enable cuBLAS
ENV GGML_CUDA=1
RUN make
RUN make base.en
ENTRYPOINT ["/app/main"]

View File

@ -17,7 +17,7 @@ ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
ENV GGML_CUDA=1
RUN apt-get update && \
apt-get install -y build-essential libsdl2-dev \
apt-get install -y build-essential libsdl2-dev wget cmake \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
# Ref: https://stackoverflow.com/a/53464012
@ -25,7 +25,7 @@ ENV CUDA_MAIN_VERSION=12.3
ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
COPY .. .
RUN make
RUN make base.en
FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime
ENV CUDA_MAIN_VERSION=12.3
@ -33,7 +33,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg \
apt-get install -y curl ffmpeg wget cmake \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app

View File

@ -2,17 +2,17 @@ FROM ubuntu:22.04 AS build
WORKDIR /app
RUN apt-get update && \
apt-get install -y build-essential \
apt-get install -y build-essential wget cmake \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY .. .
RUN make
RUN make base.en
FROM ubuntu:22.04 AS runtime
WORKDIR /app
RUN apt-get update && \
apt-get install -y curl ffmpeg libsdl2-dev \
apt-get install -y curl ffmpeg libsdl2-dev wget cmake \
&& rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
COPY --from=build /app /app

View File

@ -3,61 +3,41 @@ on:
push:
paths:
- bindings/ruby/**
- src/whisper.cpp
- include/whisper.h
- ggml/src/ggml.c
- ggml/src/ggml-impl.h
- ggml/src/ggml-aarch64.h
- ggml/src/ggml-aarch64.c
- ggml/src/ggml-alloc.c
- ggml/src/ggml-backend-impl.h
- ggml/src/ggml-backend.cpp
- ggml/src/ggml-common.h
- ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h
- ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h
- ggml/include/ggml-cuda.h
- ggml/include/ggml-kompute.h
- ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/dr_wav.h
pull_request:
paths:
- bindings/ruby/**
- src/whisper.cpp
- include/whisper.h
- ggml/src/ggml.c
- ggml/src/ggml-impl.h
- ggml/src/ggml-aarch64.h
- ggml/src/ggml-aarch64.c
- ggml/src/ggml-alloc.c
- ggml/src/ggml-backend-impl.h
- ggml/src/ggml-backend.cpp
- ggml/src/ggml-common.h
- ggml/src/ggml-quants.h
- ggml/src/ggml-quants.c
- ggml/src/ggml-cpu-impl.h
- ggml/src/ggml-metal.m
- ggml/src/ggml-metal.metal
- ggml/src/ggml-blas.cpp
- ggml/include/ggml.h
- ggml/include/ggml-alloc.h
- ggml/include/ggml-backend.h
- ggml/include/ggml-cuda.h
- ggml/include/ggml-kompute.h
- ggml/include/ggml-metal.h
- ggml/include/ggml-sycl.h
- ggml/include/ggml-vulkan.h
- ggml/include/ggml-blas.h
- src/**/*.c
- src/**/*.cpp
- src/**/*.h
- src/**/*.m
- src/**/*.metal
- include/**/*.c
- include/**/*.cpp
- include/**/*.h
- include/**/*.m
- include/**/*.metal
- ggml/**/*.c
- ggml/**/*.cpp
- ggml/**/*.h
- ggml/**/*.m
- ggml/**/*.metal
- scripts/get-flags.mk
- examples/dr_wav.h
@ -70,6 +50,6 @@ jobs:
steps:
- uses: ruby/setup-ruby@v1
with:
ruby-version: '3.0'
ruby-version: '3.1'
- uses: actions/checkout@v4
- run: rake test

View File

@ -28,9 +28,9 @@ jobs:
-w /workspace ${{ env.ubuntu_image }} /bin/sh -c '
set -e
apt update
apt install -y build-essential libsdl2-dev
make
make stream'
apt install -y build-essential libsdl2-dev cmake
cmake -B build
cmake --build build --config Release -j $(nproc)'
macOS-latest:
runs-on: macOS-latest
@ -42,30 +42,30 @@ jobs:
- name: Dependencies
run: |
brew update
brew install sdl2
brew install sdl2 cmake
- name: Build
run: |
make
make stream
cmake -B build
cmake --build build --config Release
freeBSD-latest:
runs-on: macos-12
steps:
- name: Clone
uses: actions/checkout@v4
- name: Build
uses: cross-platform-actions/action@v0.24.0
with:
operating_system: freebsd
version: '13.3'
run: |
sudo pkg update
sudo pkg install -y gmake sdl2
gmake
gmake stream
# freeBSD-latest:
# runs-on: macos-12
#
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: Build
# uses: cross-platform-actions/action@v0.24.0
# with:
# operating_system: freebsd
# version: '13.3'
# run: |
# sudo pkg update
# sudo pkg install -y gmake sdl2 cmake
# cmake -B build
# cmake --build build --config Release
ubuntu-latest-gcc:
runs-on: ubuntu-latest
@ -280,25 +280,10 @@ jobs:
mingw-w64-${{matrix.env}}-SDL2
mingw-w64-${{matrix.env}}-openblas
- name: Build using make
shell: msys2 {0}
run: |
make -j $(nproc)
- name: Clean after building using make
shell: msys2 {0}
run: |
make clean
- name: Build using make w/ OpenBLAS
shell: msys2 {0}
run: |
make GGML_OPENBLAS=1 -j $(nproc)
- name: Build using CMake
shell: msys2 {0}
run: |
cmake -B build
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config ${{ matrix.build }} -j $(nproc)
- name: Clean after building using CMake
@ -445,71 +430,72 @@ jobs:
name: whisper-blas-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
windows-cublas:
runs-on: windows-2019
strategy:
matrix:
build: [Release]
arch: [x64]
cublas: [ON]
sdl2: [ON]
cuda-toolkit: [12.2.0, 11.8.0]
include:
- arch: x64
s2arc: x64
- sdl2: ON
s2ver: 2.28.5
steps:
- name: Clone
uses: actions/checkout@v4
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2
- name: Install CUDA Toolkit
id: cuda-toolkit
uses: Jimver/cuda-toolkit@v0.2.15
with:
cuda: '${{ matrix.cuda-toolkit }}'
- name: Fetch SDL2 and set SDL2_DIR
if: matrix.sdl2 == 'ON'
run: |
C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
7z x sdl2.zip
echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
- name: Configure
run: >
cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DGGML_CUDA=${{ matrix.cublas }}
-DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build ${{ matrix.cuda-toolkit }}
run: |
cd ./build
cmake --build . --config ${{ matrix.build }}
- name: Copy CUDA DLLs
run: >
Copy-Item -PassThru
-Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
-Include cudart64_*,cublas64_*,cublasLt64_*
-Destination build/bin/${{ matrix.build }}
- name: Copy SDL2.dll
if: matrix.sdl2 == 'ON'
run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
- name: Upload binaries
if: matrix.sdl2 == 'ON'
uses: actions/upload-artifact@v4
with:
name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
path: build/bin/${{ matrix.build }}
# TODO: fix and re-enable
# windows-cublas:
# runs-on: windows-2019
#
# strategy:
# matrix:
# build: [Release]
# arch: [x64]
# cublas: [ON]
# sdl2: [ON]
# cuda-toolkit: [12.2.0, 11.8.0]
# include:
# - arch: x64
# s2arc: x64
# - sdl2: ON
# s2ver: 2.28.5
#
# steps:
# - name: Clone
# uses: actions/checkout@v4
#
# - name: Add msbuild to PATH
# uses: microsoft/setup-msbuild@v2
#
# - name: Install CUDA Toolkit
# id: cuda-toolkit
# uses: Jimver/cuda-toolkit@v0.2.15
# with:
# cuda: '${{ matrix.cuda-toolkit }}'
#
# - name: Fetch SDL2 and set SDL2_DIR
# if: matrix.sdl2 == 'ON'
# run: |
# C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip
# 7z x sdl2.zip
# echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV
#
# - name: Configure
# run: >
# cmake -S . -B ./build -A ${{ matrix.arch }}
# -DCMAKE_BUILD_TYPE=${{ matrix.build }}
# -DGGML_CUDA=${{ matrix.cublas }}
# -DWHISPER_SDL2=${{ matrix.sdl2 }}
#
# - name: Build ${{ matrix.cuda-toolkit }}
# run: |
# cd ./build
# cmake --build . --config ${{ matrix.build }}
#
# - name: Copy CUDA DLLs
# run: >
# Copy-Item -PassThru
# -Path "${{ steps.cuda-toolkit.outputs.CUDA_PATH }}/bin/*.dll"
# -Include cudart64_*,cublas64_*,cublasLt64_*
# -Destination build/bin/${{ matrix.build }}
#
# - name: Copy SDL2.dll
# if: matrix.sdl2 == 'ON'
# run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
#
# - name: Upload binaries
# if: matrix.sdl2 == 'ON'
# uses: actions/upload-artifact@v4
# with:
# name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}
# path: build/bin/${{ matrix.build }}
emscripten:
runs-on: ubuntu-latest
@ -533,7 +519,7 @@ jobs:
emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
make
ios:
ios-xcode-build:
runs-on: macos-latest
strategy:
@ -541,7 +527,7 @@ jobs:
build: [Release]
steps:
- name: Clone
- name: Checkout code
uses: actions/checkout@v4
- name: Configure
@ -549,11 +535,34 @@ jobs:
cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin
mkdir models/ggml-base.en-encoder.mlmodelc
- name: Build objc example
run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphonesimulator build
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -G Xcode .. \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DWHISPER_BUILD_EXAMPLES=OFF \
-DWHISPER_BUILD_TESTS=OFF \
-DWHISPER_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=iOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
sudo cmake --install . --config Release
- name: xcodebuild for swift package
id: xcodebuild
run: |
xcodebuild -scheme whisper-Package -destination 'generic/platform=iOS'
#- name: Build objc example
# run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos build
- name: Build swiftui example
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphonesimulator build
run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build
android:
runs-on: ubuntu-latest
@ -664,5 +673,6 @@ jobs:
- name: Test quantize
run: |
./models/download-ggml-model.sh tiny.en
make quantize
./quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0
cmake -B build
cmake --build build --config Release
./build/bin/quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0

4
.gitignore vendored
View File

@ -1,5 +1,6 @@
*.o
*.a
*.d
.cache/
.coreml/
.test/
@ -19,6 +20,9 @@ build-*/
.swiftpm
*.metallib
ggml-metal-embed.metal
ggml-metal-embed.metal.tmp
/main
/stream
/command

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories.
project("whisper.cpp" C CXX)
project("whisper.cpp" VERSION 1.7.2)
project("whisper.cpp" VERSION 1.7.3)
include(CheckIncludeFileCXX)
set(SOVERSION 1)

1138
Makefile

File diff suppressed because it is too large Load Diff

View File

@ -14,49 +14,6 @@ let package = Package(
.library(name: "whisper", targets: ["whisper"]),
],
targets: [
.target(
name: "whisper",
path: ".",
exclude: [
"build",
"bindings",
"cmake",
"examples",
"scripts",
"models",
"samples",
"tests",
"CMakeLists.txt",
"Makefile",
"ggml/src/ggml-metal-embed.metal"
],
sources: [
"ggml/src/ggml.c",
"src/whisper.cpp",
"ggml/src/ggml-aarch64.c",
"ggml/src/ggml-alloc.c",
"ggml/src/ggml-backend.cpp",
"ggml/src/ggml-cpu.c",
"ggml/src/ggml-quants.c",
"ggml/src/ggml-metal.m"
],
resources: [.process("ggml/src/ggml-metal.metal")],
publicHeadersPath: "spm-headers",
cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
.define("GGML_USE_ACCELERATE"),
.unsafeFlags(["-fno-objc-arc"]),
.define("GGML_USE_METAL")
// NOTE: NEW_LAPACK will required iOS version 16.4+
// We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
],
linkerSettings: [
.linkedFramework("Accelerate")
]
)
],
cxxLanguageStandard: .cxx11
.systemLibrary(name: "whisper", pkgConfig: "whisper"),
]
)

343
README.md
View File

@ -7,7 +7,7 @@
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.7.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.2) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
Stable: [v1.7.3](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.7.3) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
@ -53,18 +53,6 @@ On Apple Silicon, the inference runs fully on the GPU via Metal:
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
## Implementation details
- The core tensor operations are implemented in C ([ggml.h](ggml/include/ggml.h) / [ggml.c](ggml/src/ggml.c))
- The transformer model and the high-level C-style API are implemented in C++ ([whisper.h](include/whisper.h) / [whisper.cpp](src/whisper.cpp))
- Sample usage is demonstrated in [main.cpp](examples/main)
- Sample real-time audio transcription from the microphone is demonstrated in [stream.cpp](examples/stream)
- Various other examples are available in the [examples](examples) folder
The tensor operators are optimized heavily for Apple silicon CPUs. Depending on the computation size, Arm Neon SIMD intrinsics or CBLAS Accelerate framework routines are used. The latter are especially effective for bigger sizes since the Accelerate framework utilizes the special-purpose AMX coprocessor available in modern Apple products.
## Quick start
First clone the repository:
@ -85,134 +73,26 @@ Then, download one of the Whisper [models](models/README.md) converted in [`ggml
sh ./models/download-ggml-model.sh base.en
```
Now build the [main](examples/main) example and transcribe an audio file like this:
Now build the [whisper-cli](examples/cli) example and transcribe an audio file like this:
```bash
# build the main example
make -j
# build the project
cmake -B build
cmake --build build --config Release
# transcribe an audio file
./main -f samples/jfk.wav
./build/bin/whisper-cli -f samples/jfk.wav
```
---
For a quick demo, simply run `make base.en`:
```text
$ make -j base.en
cc -I. -O3 -std=c11 -pthread -DGGML_USE_ACCELERATE -c ggml.c -o ggml.o
c++ -I. -I./examples -O3 -std=c++11 -pthread -c whisper.cpp -o whisper.o
c++ -I. -I./examples -O3 -std=c++11 -pthread examples/main/main.cpp whisper.o ggml.o -o main -framework Accelerate
./main -h
usage: ./main [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
-t N, --threads N [4 ] number of threads to use during computation
-p N, --processors N [1 ] number of processors to use during computation
-ot N, --offset-t N [0 ] time offset in milliseconds
-on N, --offset-n N [0 ] segment index offset
-d N, --duration N [0 ] duration of audio to process in milliseconds
-mc N, --max-context N [-1 ] maximum number of text context tokens to store
-ml N, --max-len N [0 ] maximum segment length in characters
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
-tdrz, --tinydiarize [false ] enable tinydiarize (requires a tdrz model)
-nf, --no-fallback [false ] do not use temperature fallback while decoding
-otxt, --output-txt [false ] output result in a text file
-ovtt, --output-vtt [false ] output result in a vtt file
-osrt, --output-srt [false ] output result in a srt file
-olrc, --output-lrc [false ] output result in a lrc file
-owts, --output-words [false ] output script for generating karaoke video
-fp, --font-path [/System/Library/Fonts/Supplemental/Courier New Bold.ttf] path to a monospace font for karaoke video
-ocsv, --output-csv [false ] output result in a CSV file
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
sh ./models/download-ggml-model.sh base.en
Downloading ggml model base.en ...
ggml-base.en.bin 100%[========================>] 141.11M 6.34MB/s in 24s
Done! Model 'base.en' saved in 'models/ggml-base.en.bin'
You can now use it like this:
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
===============================================
Running base.en on all samples in ./samples ...
===============================================
----------------------------------------------
[+] Running base.en on samples/jfk.wav ... (run 'ffplay samples/jfk.wav' to listen)
----------------------------------------------
whisper_init_from_file: loading model from 'models/ggml-base.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 512
whisper_model_load: n_audio_head = 8
whisper_model_load: n_audio_layer = 6
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 512
whisper_model_load: n_text_head = 8
whisper_model_load: n_text_layer = 6
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 2
whisper_model_load: mem required = 215.00 MB (+ 6.00 MB per decoder)
whisper_model_load: kv self size = 5.25 MB
whisper_model_load: kv cross size = 17.58 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 140.60 MB
whisper_model_load: model size = 140.54 MB
system_info: n_threads = 4 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:11.000] And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.
whisper_print_timings: fallbacks = 0 p / 0 h
whisper_print_timings: load time = 113.81 ms
whisper_print_timings: mel time = 15.40 ms
whisper_print_timings: sample time = 11.58 ms / 27 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 266.60 ms / 1 runs ( 266.60 ms per run)
whisper_print_timings: decode time = 66.11 ms / 27 runs ( 2.45 ms per run)
whisper_print_timings: total time = 476.31 ms
```
For a quick demo, simply run `make base.en`.
The command downloads the `base.en` model converted to custom `ggml` format and runs the inference on all `.wav` samples in the folder `samples`.
For detailed usage instructions, run: `./main -h`
For detailed usage instructions, run: `./build/bin/whisper-cli -h`
Note that the [main](examples/main) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
Note that the [whisper-cli](examples/cli) example currently runs only with 16-bit WAV files, so make sure to convert your input before running the tool.
For example, you can use `ffmpeg` like this:
```bash
@ -265,11 +145,12 @@ Here are the steps for creating and using a quantized model:
```bash
# quantize a model with Q5_0 method
make -j quantize
./quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
cmake -B build
cmake --build build --config Release
./build/bin/quantize models/ggml-base.en.bin models/ggml-base.en-q5_0.bin q5_0
# run the examples as usual, specifying the quantized model file
./main -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
./build/bin/whisper-cli -m models/ggml-base.en-q5_0.bin ./samples/gb0.wav
```
## Core ML support
@ -303,10 +184,6 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
- Build `whisper.cpp` with Core ML support:
```bash
# using Makefile
make clean
WHISPER_COREML=1 make -j
# using CMake
cmake -B build -DWHISPER_COREML=1
cmake --build build -j --config Release
@ -315,7 +192,7 @@ speed-up - more than x3 faster compared with CPU-only execution. Here are the in
- Run the examples as usual. For example:
```text
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
$ ./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav
...
@ -399,7 +276,7 @@ This can result in significant speedup in encoder performance. Here are the inst
- Run the examples as usual. For example:
```text
$ ./main -m models/ggml-base.en.bin -f samples/jfk.wav
$ ./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/jfk.wav
...
@ -426,8 +303,8 @@ First, make sure you have installed `cuda`: https://developer.nvidia.com/cuda-do
Now build `whisper.cpp` with CUDA support:
```
make clean
GGML_CUDA=1 make -j
cmake -B build -DGGML_CUDA=1
cmake --build build -j --config Release
```
## Vulkan GPU support
@ -436,8 +313,8 @@ First, make sure your graphics card driver provides support for Vulkan API.
Now build `whisper.cpp` with Vulkan support:
```
make clean
make GGML_VULKAN=1 -j
cmake -B build -DGGML_VULKAN=1
cmake --build build -j --config Release
```
## BLAS CPU support via OpenBLAS
@ -448,28 +325,13 @@ First, make sure you have installed `openblas`: https://www.openblas.net/
Now build `whisper.cpp` with OpenBLAS support:
```
make clean
GGML_OPENBLAS=1 make -j
```
## BLAS CPU support via Intel MKL
Encoder processing can be accelerated on the CPU via the BLAS compatible interface of Intel's Math Kernel Library.
First, make sure you have installed Intel's MKL runtime and development packages: https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-download.html
Now build `whisper.cpp` with Intel MKL BLAS support:
```
source /opt/intel/oneapi/setvars.sh
mkdir build
cd build
cmake -DWHISPER_MKL=ON ..
WHISPER_MKL=1 make -j
cmake -B build -DGGML_BLAS=1
cmake --build build -j --config Release
```
## Ascend NPU support
Ascend NPU provides inference acceleration via [`CANN`](https://www.hiascend.com/en/software/cann) and AI cores.
Ascend NPU provides inference acceleration via [`CANN`](https://www.hiascend.com/en/software/cann) and AI cores.
First, check if your Ascend NPU device is supported:
@ -483,16 +345,14 @@ Then, make sure you have installed [`CANN toolkit`](https://www.hiascend.com/en/
Now build `whisper.cpp` with CANN support:
```
mkdir build
cd build
cmake .. -D GGML_CANN=on
make -j
cmake -B build -DGGML_CANN=1
cmake --build build -j --config Release
```
Run the inference examples as usual, for example:
```
./build/bin/main -f samples/jfk.wav -m models/ggml-base.en.bin -t 8
./build/bin/whisper-cli -f samples/jfk.wav -m models/ggml-base.en.bin -t 8
```
*Notes:*
@ -500,38 +360,6 @@ Run the inference examples as usual, for example:
- If you have trouble with Ascend NPU device, please create a issue with **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table `Verified devices`.
## Docker
### Prerequisites
- Docker must be installed and running on your system.
- Create a folder to store big models & intermediate files (ex. /whisper/models)
### Images
We have two Docker images available for this project:
1. `ghcr.io/ggerganov/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`)
2. `ghcr.io/ggerganov/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`)
### Usage
```shell
# download model and persist it in a local folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./models/download-ggml-model.sh base /models"
# transcribe an audio file
docker run -it --rm \
-v path/to/models:/models \
-v path/to/audios:/audios \
whisper.cpp:main "./main -m /models/ggml-base.bin -f /audios/jfk.wav"
# transcribe an audio file in samples folder
docker run -it --rm \
-v path/to/models:/models \
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
```
## Installing with Conan
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
@ -546,89 +374,6 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
- Inference only
## Another example
Here is another example of transcribing a [3:24 min speech](https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg)
in about half a minute on a MacBook M1 Pro, using `medium.en` model:
<details>
<summary>Expand to see the result</summary>
```text
$ ./main -m models/ggml-medium.en.bin -f samples/gb1.wav -t 8
whisper_init_from_file: loading model from 'models/ggml-medium.en.bin'
whisper_model_load: loading model
whisper_model_load: n_vocab = 51864
whisper_model_load: n_audio_ctx = 1500
whisper_model_load: n_audio_state = 1024
whisper_model_load: n_audio_head = 16
whisper_model_load: n_audio_layer = 24
whisper_model_load: n_text_ctx = 448
whisper_model_load: n_text_state = 1024
whisper_model_load: n_text_head = 16
whisper_model_load: n_text_layer = 24
whisper_model_load: n_mels = 80
whisper_model_load: f16 = 1
whisper_model_load: type = 4
whisper_model_load: mem required = 1720.00 MB (+ 43.00 MB per decoder)
whisper_model_load: kv self size = 42.00 MB
whisper_model_load: kv cross size = 140.62 MB
whisper_model_load: adding 1607 extra tokens
whisper_model_load: model ctx = 1462.35 MB
whisper_model_load: model size = 1462.12 MB
system_info: n_threads = 8 / 10 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
main: processing 'samples/gb1.wav' (3179750 samples, 198.7 sec), 8 threads, 1 processors, lang = en, task = transcribe, timestamps = 1 ...
[00:00:00.000 --> 00:00:08.000] My fellow Americans, this day has brought terrible news and great sadness to our country.
[00:00:08.000 --> 00:00:17.000] At nine o'clock this morning, Mission Control in Houston lost contact with our Space Shuttle Columbia.
[00:00:17.000 --> 00:00:23.000] A short time later, debris was seen falling from the skies above Texas.
[00:00:23.000 --> 00:00:29.000] The Columbia's lost. There are no survivors.
[00:00:29.000 --> 00:00:32.000] On board was a crew of seven.
[00:00:32.000 --> 00:00:39.000] Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark,
[00:00:39.000 --> 00:00:48.000] Captain David Brown, Commander William McCool, Dr. Kultna Shavla, and Ilan Ramon,
[00:00:48.000 --> 00:00:52.000] a colonel in the Israeli Air Force.
[00:00:52.000 --> 00:00:58.000] These men and women assumed great risk in the service to all humanity.
[00:00:58.000 --> 00:01:03.000] In an age when space flight has come to seem almost routine,
[00:01:03.000 --> 00:01:07.000] it is easy to overlook the dangers of travel by rocket
[00:01:07.000 --> 00:01:12.000] and the difficulties of navigating the fierce outer atmosphere of the Earth.
[00:01:12.000 --> 00:01:18.000] These astronauts knew the dangers, and they faced them willingly,
[00:01:18.000 --> 00:01:23.000] knowing they had a high and noble purpose in life.
[00:01:23.000 --> 00:01:31.000] Because of their courage and daring and idealism, we will miss them all the more.
[00:01:31.000 --> 00:01:36.000] All Americans today are thinking as well of the families of these men and women
[00:01:36.000 --> 00:01:40.000] who have been given this sudden shock and grief.
[00:01:40.000 --> 00:01:45.000] You're not alone. Our entire nation grieves with you,
[00:01:45.000 --> 00:01:52.000] and those you love will always have the respect and gratitude of this country.
[00:01:52.000 --> 00:01:56.000] The cause in which they died will continue.
[00:01:56.000 --> 00:02:04.000] Mankind is led into the darkness beyond our world by the inspiration of discovery
[00:02:04.000 --> 00:02:11.000] and the longing to understand. Our journey into space will go on.
[00:02:11.000 --> 00:02:16.000] In the skies today, we saw destruction and tragedy.
[00:02:16.000 --> 00:02:22.000] Yet farther than we can see, there is comfort and hope.
[00:02:22.000 --> 00:02:29.000] In the words of the prophet Isaiah, "Lift your eyes and look to the heavens
[00:02:29.000 --> 00:02:35.000] who created all these. He who brings out the starry hosts one by one
[00:02:35.000 --> 00:02:39.000] and calls them each by name."
[00:02:39.000 --> 00:02:46.000] Because of His great power and mighty strength, not one of them is missing.
[00:02:46.000 --> 00:02:55.000] The same Creator who names the stars also knows the names of the seven souls we mourn today.
[00:02:55.000 --> 00:03:01.000] The crew of the shuttle Columbia did not return safely to earth,
[00:03:01.000 --> 00:03:05.000] yet we can pray that all are safely home.
[00:03:05.000 --> 00:03:13.000] May God bless the grieving families, and may God continue to bless America.
[00:03:13.000 --> 00:03:19.000] [Silence]
whisper_print_timings: fallbacks = 1 p / 0 h
whisper_print_timings: load time = 569.03 ms
whisper_print_timings: mel time = 146.85 ms
whisper_print_timings: sample time = 238.66 ms / 553 runs ( 0.43 ms per run)
whisper_print_timings: encode time = 18665.10 ms / 9 runs ( 2073.90 ms per run)
whisper_print_timings: decode time = 13090.93 ms / 549 runs ( 23.85 ms per run)
whisper_print_timings: total time = 32733.52 ms
```
</details>
## Real-time audio input example
This is a naive example of performing real-time inference on audio from your microphone.
@ -636,8 +381,9 @@ The [stream](examples/stream) tool samples the audio every half a second and run
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```bash
make stream -j
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
cmake -B build
cmake --build build --config Release
./build/bin/stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
@ -648,7 +394,7 @@ Adding the `--print-colors` argument will print the transcribed text using an ex
to highlight words with high or low confidence:
```bash
./main -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
./build/bin/whisper-cli -m models/ggml-base.en.bin -f samples/gb0.wav --print-colors
```
<img width="965" alt="image" src="https://user-images.githubusercontent.com/1991296/197356445-311c8643-9397-4e5e-b46e-0b4b4daa2530.png">
@ -658,7 +404,7 @@ to highlight words with high or low confidence:
For example, to limit the line length to a maximum of 16 characters, simply add `-ml 16`:
```text
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
$ ./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 16
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
@ -682,7 +428,7 @@ main: processing './samples/jfk.wav' (176000 samples, 11.0 sec), 4 threads, 1 pr
The `--max-len` argument can be used to obtain word-level timestamps. Simply use `-ml 1`:
```text
$ ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
$ ./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -ml 1
whisper_model_load: loading model from './models/ggml-base.en.bin'
...
@ -729,7 +475,7 @@ Sample usage:
./models/download-ggml-model.sh small.en-tdrz
# run as usual, adding the "-tdrz" command-line argument
./main -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
./build/bin/whisper-cli -f ./samples/a13.wav -m ./models/ggml-small.en-tdrz.bin -tdrz
...
main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 processors, lang = en, task = transcribe, tdrz = 1, timestamps = 1 ...
...
@ -746,14 +492,14 @@ main: processing './samples/a13.wav' (480000 samples, 30.0 sec), 4 threads, 1 pr
## Karaoke-style movie generation (experimental)
The [main](examples/main) example provides support for output of karaoke-style movies, where the
The [whisper-cli](examples/cli) example provides support for output of karaoke-style movies, where the
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed.
Here are a few _"typical"_ examples:
```bash
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts
source ./samples/jfk.wav.wts
ffplay ./samples/jfk.wav.mp4
```
@ -763,7 +509,7 @@ https://user-images.githubusercontent.com/1991296/199337465-dbee4b5e-9aeb-48a3-b
---
```bash
./main -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/mm0.wav -owts
source ./samples/mm0.wav.wts
ffplay ./samples/mm0.wav.mp4
```
@ -773,7 +519,7 @@ https://user-images.githubusercontent.com/1991296/199337504-cc8fd233-0cb7-4920-9
---
```bash
./main -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
./build/bin/whisper-cli -m ./models/ggml-base.en.bin -f ./samples/gb0.wav -owts
source ./samples/gb0.wav.wts
ffplay ./samples/gb0.wav.mp4
```
@ -798,7 +544,7 @@ https://user-images.githubusercontent.com/1991296/223206245-2d36d903-cf8e-4f09-8
## Benchmarks
In order to have an objective comparison of the performance of the inference across different system configurations,
use the [bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
use the [whisper-bench](examples/bench) tool. The tool simply runs the Encoder part of the model and prints how much time it
took to execute it. The results are summarized in the following Github issue:
[Benchmark results](https://github.com/ggerganov/whisper.cpp/issues/89)
@ -861,13 +607,12 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| Example | Web | Description |
| --------------------------------------------------- | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- |
| [main](examples/main) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
| [bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
| [talk](examples/talk) | [talk.wasm](examples/talk.wasm) | Talk with a GPT-2 bot |
| [talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper-cli](examples/cli) | [whisper.wasm](examples/whisper.wasm) | Tool for translating and transcribing audio using Whisper |
| [whisper-bench](examples/bench) | [bench.wasm](examples/bench.wasm) | Benchmark the performance of Whisper on your machine |
| [whisper-stream](examples/stream) | [stream.wasm](examples/stream.wasm) | Real-time transcription of raw microphone capture |
| [whisper-command](examples/command) | [command.wasm](examples/command.wasm) | Basic voice assistant example for receiving voice commands from the mic |
| [whisper-server](examples/server) | | HTTP transcription server with OAI-like API |
| [whisper-talk-llama](examples/talk-llama) | | Talk with a LLaMA bot |
| [whisper.objc](examples/whisper.objc) | | iOS mobile application using whisper.cpp |
| [whisper.swiftui](examples/whisper.swiftui) | | SwiftUI iOS / macOS application using whisper.cpp |
| [whisper.android](examples/whisper.android) | | Android mobile application using whisper.cpp |
@ -875,7 +620,7 @@ Some of the examples are even ported to run in the browser using WebAssembly. Ch
| [generate-karaoke.sh](examples/generate-karaoke.sh) | | Helper script to easily [generate a karaoke video](https://youtu.be/uj7hVta4blM) of raw audio capture |
| [livestream.sh](examples/livestream.sh) | | [Livestream audio transcription](https://github.com/ggerganov/whisper.cpp/issues/185) |
| [yt-wsp.sh](examples/yt-wsp.sh) | | Download + transcribe and/or translate any VOD [(original)](https://gist.github.com/DaniruKun/96f763ec1a037cc92fe1a059b643b818) |
| [server](examples/server) | | HTTP transcription server with OAI-like API |
| [wchess](examples/wchess) | [wchess.wasm](examples/wchess) | Voice-controlled chess |
## [Discussions](https://github.com/ggerganov/whisper.cpp/discussions)

View File

@ -0,0 +1,5 @@
module whisper [system] {
header "whisper.h"
link "whisper"
export *
}

View File

@ -0,0 +1,4 @@
#pragma once
#include <whisper.h>

View File

@ -67,5 +67,5 @@ copy /y ..\..\build\bin\Release\whisper.dll build\generated\resources\main\win32
## License
The license for the Go bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.
The license for the Java bindings is the same as the license for the rest of the whisper.cpp project, which is the MIT License. See the `LICENSE` file for more details.

View File

@ -1,6 +1,6 @@
{
"name": "whisper.cpp",
"version": "1.7.2",
"version": "1.7.3",
"description": "Whisper speech recognition",
"main": "whisper.js",
"scripts": {

View File

@ -1,3 +1,5 @@
LICENSE
pkg/
lib/whisper.*
lib/whisper.so
lib/whisper.bundle
lib/whisper.dll

View File

@ -22,7 +22,7 @@ Usage
```ruby
require "whisper"
whisper = Whisper::Context.new("path/to/model.bin")
whisper = Whisper::Context.new("base")
params = Whisper::Params.new
params.language = "en"
@ -41,21 +41,66 @@ end
### Preparing model ###
Use script to download model file(s):
Some models are prepared up-front:
```bash
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
sh ./models/download-ggml-model.sh base.en
```ruby
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
There are some types of models. See [models][] page for details.
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
```ruby
Whisper::Model.pre_converted_models["base"].clear_cache
```
You also can use shorthand for pre-converted models:
```ruby
whisper = Whisper::Context.new("base.en")
```
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
```ruby
puts Whisper::Model.preconverted_model_names
# tiny
# tiny.en
# tiny-q5_1
# tiny.en-q5_1
# tiny-q8_0
# base
# base.en
# base-q5_1
# base.en-q5_1
# base-q8_0
# :
# :
```
You can also use local model files you prepared:
```ruby
whisper = Whisper::Context.new("path/to/your/model.bin")
```
Or, you can download model files:
```ruby
model_uri = Whisper::Model::URI.new("http://example.net/uri/of/your/model.bin")
whisper = Whisper::Context.new(model_uri)
```
See [models][] page for details.
### Preparing audio file ###
Currently, whisper.cpp accepts only 16-bit WAV files.
### API ###
API
---
### Segments ###
Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`:
@ -85,13 +130,6 @@ end
You can also add hook to params called on new segment:
```ruby
def format_time(time_ms)
sec, decimal_part = time_ms.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
# Add hook before calling #transcribe
params.on_new_segment do |segment|
line = "[%{st} --> %{ed}] %{text}" % {
@ -107,10 +145,12 @@ whisper.transcribe("path/to/audio.wav", params)
```
### Models ###
You can see model information:
```ruby
whisper = Whisper::Context.new("path/to/model.bin")
whisper = Whisper::Context.new("base")
model = whisper.model
model.n_vocab # => 51864
@ -128,6 +168,8 @@ model.type # => "base"
```
### Logging ###
You can set log callback:
```ruby
@ -157,9 +199,29 @@ Using this feature, you are also able to suppress log:
Whisper.log_set ->(level, buffer, user_data) {
# do nothing
}, nil
Whisper::Context.new(MODEL)
Whisper::Context.new("base")
```
### Low-level API to transcribe ###
You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility.
```ruby
require "whisper"
require "wavefile"
reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
whisper = Whisper::Context.new("base")
whisper.full(Whisper::Params.new, samples)
whisper.each_segment do |segment|
puts segment.text
end
```
The second argument `samples` may be an array, an object with `length` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.
License
-------

View File

@ -1,39 +1,30 @@
require 'rake/clean'
require "bundler/gem_tasks"
require "pathname"
require "yaml"
require "rake/testtask"
require_relative "extsources"
extsources = YAML.load_file("extsources.yaml")
SOURCES = FileList[]
extsources.each do |src|
EXTSOURCES.each do |src|
basename = src.pathmap("%f")
dest = basename == "LICENSE" ? basename : basename.pathmap("ext/%f")
dest = basename == "LICENSE" ? basename : src.pathmap("%{../..,ext}p")
dir = dest.pathmap("%d")
file src
file dest => src do |t|
directory dir
file dest => [src, dir] do |t|
cp t.source, t.name
end
SOURCES.include dest
end
CLEAN.include SOURCES
CLEAN.include FileList[
"ext/*.o",
"ext/*.metal",
"ext/whisper.{so,bundle,dll}",
"ext/depend"
]
task build: FileList[
"ext/Makefile",
"ext/ruby_whisper.h",
"ext/ruby_whisper.cpp",
"whispercpp.gemspec",
]
CLEAN.include SOURCES
CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
directory "pkg"
CLOBBER.include "pkg"
TEST_MODEL = "../../models/ggml-base.en.bin"
LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
SO_FILE = File.join("ext", LIB_NAME)
LIB_FILE = File.join("lib", LIB_NAME)
@ -49,20 +40,25 @@ file SO_FILE => "ext/Makefile" do |t|
sh "make"
end
end
CLEAN.include LIB_FILE
CLEAN.include SO_FILE
directory "lib"
file LIB_FILE => [SO_FILE, "lib"] do |t|
copy t.source, t.name
end
CLEAN.include LIB_FILE
Rake::TestTask.new do |t|
t.test_files = FileList["tests/test_*.rb"]
end
task test: [TEST_MODEL, LIB_FILE]
file TEST_MODEL do
Dir.chdir "../.." do
sh "./models/download-ggml-model.sh base.en"
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
Dir.chdir "tests/jfk_reader" do
ruby "extconf.rb"
sh "make"
end
end
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
task test: [LIB_FILE, TEST_MEMORY_VIEW]

View File

@ -1,35 +1,13 @@
Makefile
ggml.c
ggml.h
ggml-alloc.c
ggml-alloc.h
ggml-aarch64.c
ggml-aarch64.h
ggml-backend.cpp
ggml-backend-impl.h
ggml-backend.c
ggml-backend.h
ggml-common.h
ggml-cpu-impl.h
ggml-metal.m
ggml-metal.metal
ggml-metal-embed.metal
ggml-blas.cpp
ggml-cuda.h
ggml-impl.h
ggml-kompute.h
ggml-metal.h
ggml-opencl.h
ggml-quants.c
ggml-quants.h
ggml-sycl.h
ggml-vulkan.h
ggml-blas.h
get-flags.mk
whisper.cpp
whisper.h
dr_wav.h
depend
whisper.bundle
whisper.so
whisper.bundle
whisper.dll
scripts/get-flags.mk
*.o
*.c
*.cpp
*.h
*.m
*.metal
!ruby_whisper.cpp
!ruby_whisper.h

9
bindings/ruby/ext/cpu.mk Normal file
View File

@ -0,0 +1,9 @@
ggml/src/ggml-cpu/ggml-cpu-cpp.o: \
ggml/src/ggml-cpu/ggml-cpu.cpp \
ggml/include/ggml-backend.h \
ggml/include/ggml.h \
ggml/include/ggml-alloc.h \
ggml/src/ggml-backend-impl.h \
ggml/include/ggml-cpu.h \
ggml/src/ggml-impl.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@ -1,7 +1,7 @@
require 'mkmf'
# need to use c++ compiler flags
$CXXFLAGS << ' -std=c++11'
$CXXFLAGS << ' -std=c++17'
$LDFLAGS << ' -lstdc++'
@ -35,10 +35,10 @@ if $GGML_METAL
$GGML_METAL_EMBED_LIBRARY = true
end
$MK_CPPFLAGS = ''
$MK_CPPFLAGS = '-Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -Iexamples'
$MK_CFLAGS = '-std=c11 -fPIC'
$MK_CXXFLAGS = '-std=c++11 -fPIC'
$MK_NVCCFLAGS = '-std=c++11'
$MK_CXXFLAGS = '-std=c++17 -fPIC'
$MK_NVCCFLAGS = '-std=c++17'
$MK_LDFLAGS = ''
$OBJ_GGML = []
@ -111,11 +111,6 @@ unless ENV['RISCV']
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
if $UNAME_M.match? /aarch64.*/
$MK_CFLAGS << ' -mcpu=native'
$MK_CXXFLAGS << ' -mcpu=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
@ -123,11 +118,11 @@ end
unless ENV['GGML_NO_ACCELERATE']
if $UNAME_S == 'Darwin'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS'
$MK_CPPFLAGS << ' -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE'
$MK_CPPFLAGS << ' -DACCELERATE_NEW_LAPACK'
$MK_CPPFLAGS << ' -DACCELERATE_LAPACK_ILP64'
$MK_LDFLAGS << ' -framework Accelerate'
$OBJ_GGML << 'ggml-blas.o'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
end
@ -135,20 +130,20 @@ if ENV['GGML_OPENBLAS']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas`}"
$OBJ_GGML << 'ggml-blas.o'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if ENV['GGML_OPENBLAS64']
$MK_CPPFLAGS << " -DGGML_USE_BLAS #{`pkg-config --cflags-only-I openblas64`.chomp}"
$MK_CFLAGS << " #{`pkg-config --cflags-only-other openblas64)`.chomp}"
$MK_LDFLAGS << " #{`pkg-config --libs openblas64`}"
$OBJ_GGML << 'ggml-blas.o'
$OBJ_GGML << 'ggml/src/ggml-blas/ggml-blas.o'
end
if $GGML_METAL
$MK_CPPFLAGS << ' -DGGML_USE_METAL'
$MK_LDFLAGS << ' -framework Foundation -framework Metal -framework MetalKit'
$OBJ_GGML << 'ggml-metal.o'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal.o'
if ENV['GGML_METAL_NDEBUG']
$MK_CPPFLAGS << ' -DGGML_METAL_NDEBUG'
@ -156,20 +151,27 @@ if $GGML_METAL
if $GGML_METAL_EMBED_LIBRARY
$MK_CPPFLAGS << ' -DGGML_METAL_EMBED_LIBRARY'
$OBJ_GGML << 'ggml-metal-embed.o'
$OBJ_GGML << 'ggml/src/ggml-metal/ggml-metal-embed.o'
end
end
$OBJ_GGML <<
'ggml.o' <<
'ggml-cpu.o' <<
'ggml-alloc.o' <<
'ggml-backend.o' <<
'ggml-quants.o' <<
'ggml-aarch64.o'
'ggml/src/ggml.o' <<
'ggml/src/ggml-alloc.o' <<
'ggml/src/ggml-backend.o' <<
'ggml/src/ggml-backend-reg.o' <<
'ggml/src/ggml-opt.o' <<
'ggml/src/ggml-quants.o' <<
'ggml/src/ggml-threading.o' <<
'ggml/src/ggml-cpu/ggml-cpu.o' <<
'ggml/src/ggml-cpu/ggml-cpu-cpp.o' <<
'ggml/src/ggml-cpu/ggml-cpu-aarch64.o' <<
'ggml/src/ggml-cpu/ggml-cpu-hbm.o' <<
'ggml/src/ggml-cpu/ggml-cpu-quants.o' <<
'ggml/src/ggml-cpu/ggml-cpu-traits.o'
$OBJ_WHISPER <<
'whisper.o'
'src/whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs << "ruby_whisper.o"
@ -184,9 +186,12 @@ $LDFLAGS = "#{$MK_LDFLAGS} #{$LDFLAGS}"
create_makefile('whisper')
File.open 'Makefile', 'a' do |file|
file.puts 'include get-flags.mk'
file.puts 'include scripts/get-flags.mk'
file.puts 'include cpu.mk'
if $GGML_METAL
file.puts 'include metal.mk'
if $GGML_METAL_EMBED_LIBRARY
file.puts 'include metal-embed.mk'
end

View File

@ -1,14 +1,17 @@
ggml-metal-embed.o: \
ggml-metal.metal \
ggml-common.h
ggml/src/ggml-metal/ggml-metal-embed.o: \
ggml/src/ggml-metal/ggml-metal.metal \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/src/ggml-common.h
@echo "Embedding Metal library"
@sed -e '/#include "ggml-common.h"/r ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml-metal.metal > ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)
@echo ".incbin \"ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)
@$(AS) $(TEMP_ASSEMBLY) -o $@
@rm -f ${TEMP_ASSEMBLY}
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
@echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s
$(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@
@rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s
@rmdir ${TEMP_ASSEMBLY}

View File

@ -0,0 +1,6 @@
ggml/src/ggml-metal/ggml-metal.o: \
ggml/src/ggml-metal/ggml-metal.m \
ggml/src/ggml-metal/ggml-metal-impl.h \
ggml/include/ggml-metal.h \
ggml/include/ggml.h
$(CC) $(CFLAGS) -c $< -o $@

View File

@ -1,4 +1,5 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
@ -35,11 +36,20 @@ extern "C" {
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
VALUE eError;
VALUE cSegment;
VALUE cModel;
static ID id_to_s;
static ID id_call;
static ID id___method__;
static ID id_to_enum;
static ID id_length;
static ID id_next;
static ID id_new;
static ID id_to_path;
static ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
@ -100,13 +110,13 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
VALUE old_callback = rb_iv_get(self, "@log_callback");
VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
rb_iv_set(self, "@log_callback", log_callback);
rb_iv_set(self, "@user_data", user_data);
rb_iv_set(self, "log_callback", log_callback);
rb_iv_set(self, "user_data", user_data);
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
@ -115,8 +125,8 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
if (is_log_callback_finalized) {
return;
}
VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
VALUE udata = rb_iv_get(mWhisper, "@user_data");
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
VALUE udata = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
}, nullptr);
@ -181,6 +191,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
@ -189,7 +200,9 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
/*
* call-seq:
* new("base.en") -> Whisper::Context
* new("path/to/model.bin") -> Whisper::Context
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
*/
static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
@ -199,6 +212,14 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
@ -544,6 +565,168 @@ VALUE ruby_whisper_model_type(VALUE self) {
return rb_str_new2(whisper_model_type_readable(rw->context));
}
/*
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*
* call-seq:
* full(params, samples, n_samples) -> nil
* full(params, samples) -> nil
*
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
*/
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
if (argc < 2 || argc > 3) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
if (argc == 3) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// TODO: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError appropriately
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return Qnil;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
* Result is stored in the default state of the context
* Not thread safe if executed in parallel on the same context.
* It seems this approach can offer some speedup in some cases.
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
*
* call-seq:
* full_parallel(params, samples) -> nil
* full_parallel(params, samples, n_samples) -> nil
* full_parallel(params, samples, n_samples, n_processors) -> nil
* full_parallel(params, samples, nil, n_processors) -> nil
*/
static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
if (argc < 2 || argc > 4) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
int n_processors;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
switch (argc) {
case 2:
n_processors = 1;
break;
case 3:
n_processors = 1;
break;
case 4:
n_processors = NUM2INT(argv[3]);
break;
}
if (argc >= 3 && !NIL_P(argv[2])) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// FIXME: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return Qnil;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Number of segments.
*
@ -1078,6 +1261,25 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* no_speech_thold -> Float
*/
static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold);
}
/*
* call-seq:
* no_speech_thold = threshold -> threshold
*/
static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value;
}
/*
* Sets new segment callback, called for every newly generated text segment.
*
@ -1174,9 +1376,6 @@ typedef struct {
VALUE context;
} ruby_whisper_model;
VALUE cSegment;
VALUE cModel;
static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
rb_gc_mark(rws->context);
}
@ -1518,15 +1717,61 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
return rb_str_new2(whisper_model_type_readable(rw->context));
}
static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
const int c_code = NUM2INT(code);
char *raw_message;
switch (c_code) {
case -2:
raw_message = "failed to compute log mel spectrogram";
break;
case -3:
raw_message = "failed to auto-detect language";
break;
case -4:
raw_message = "too many decoders requested";
break;
case -5:
raw_message = "audio_ctx is larger than the maximum allowed";
break;
case -6:
raw_message = "failed to encode";
break;
case -7:
raw_message = "whisper_kv_cache_init() failed for self-attention cache";
break;
case -8:
raw_message = "failed to decode";
break;
case -9:
raw_message = "failed to decode";
break;
default:
raw_message = "unknown error";
break;
}
const VALUE message = rb_str_new2(raw_message);
rb_call_super(1, &message);
rb_iv_set(self, "@code", code);
return self;
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
id___method__ = rb_intern("__method__");
id_to_enum = rb_intern("to_enum");
id_length = rb_intern("length");
id_next = rb_intern("next");
id_new = rb_intern("new");
id_to_path = rb_intern("to_path");
id_pre_converted_models = rb_intern("pre_converted_models");
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
@ -1564,6 +1809,8 @@ void Init_whisper() {
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
@ -1615,6 +1862,8 @@ void Init_whisper() {
rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
@ -1623,6 +1872,9 @@ void Init_whisper() {
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
rb_define_attr(eError, "code", true, false);
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
// High leve
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);

View File

@ -0,0 +1,6 @@
require "yaml"
sources = `git ls-files -z ../..`.split("\x0")
paths = YAML.load_file("../../.github/workflows/bindings-ruby.yml")[true]["push"]["paths"]
paths.delete "bindings/ruby/**"
EXTSOURCES = (Dir.glob(paths, base: "../..").collect {|path| "../../#{path}"} << "../../LICENSE") & sources

View File

@ -1,31 +0,0 @@
---
- ../../src/whisper.cpp
- ../../include/whisper.h
- ../../ggml/src/ggml.c
- ../../ggml/src/ggml-cpu.c
- ../../ggml/src/ggml-impl.h
- ../../ggml/src/ggml-aarch64.h
- ../../ggml/src/ggml-aarch64.c
- ../../ggml/src/ggml-alloc.c
- ../../ggml/src/ggml-backend-impl.h
- ../../ggml/src/ggml-backend.cpp
- ../../ggml/src/ggml-common.h
- ../../ggml/src/ggml-quants.h
- ../../ggml/src/ggml-quants.c
- ../../ggml/src/ggml-cpu-impl.h
- ../../ggml/src/ggml-metal.m
- ../../ggml/src/ggml-metal.metal
- ../../ggml/src/ggml-blas.cpp
- ../../ggml/include/ggml.h
- ../../ggml/include/ggml-alloc.h
- ../../ggml/include/ggml-backend.h
- ../../ggml/include/ggml-cpu.h
- ../../ggml/include/ggml-cuda.h
- ../../ggml/include/ggml-kompute.h
- ../../ggml/include/ggml-metal.h
- ../../ggml/include/ggml-sycl.h
- ../../ggml/include/ggml-vulkan.h
- ../../ggml/include/ggml-blas.h
- ../../scripts/get-flags.mk
- ../../examples/dr_wav.h
- ../../LICENSE

View File

@ -0,0 +1,2 @@
require "whisper.so"
require "whisper/model/uri"

View File

@ -0,0 +1,157 @@
require "whisper.so"
require "uri"
require "net/http"
require "time"
require "pathname"
require "io/console/size"
class Whisper::Model
class URI
def initialize(uri)
@uri = URI(uri)
end
def to_path
cache
cache_path.to_path
end
def clear_cache
path = cache_path
path.delete if path.exist?
end
private
def cache_path
base_cache_dir/@uri.host/@uri.path[1..]
end
def base_cache_dir
base = case RUBY_PLATFORM
when /mswin|mingw/
ENV.key?("LOCALAPPDATA") ? Pathname(ENV["LOCALAPPDATA"]) : Pathname(Dir.home)/"AppData/Local"
when /darwin/
Pathname(Dir.home)/"Library/Caches"
else
ENV.key?("XDG_CACHE_HOME") ? ENV["XDG_CACHE_HOME"] : Pathname(Dir.home)/".cache"
end
base/"whisper.cpp"
end
def cache
path = cache_path
headers = {}
headers["if-modified-since"] = path.mtime.httpdate if path.exist?
request @uri, headers
path
end
def request(uri, headers)
Net::HTTP.start uri.host, uri.port, use_ssl: uri.scheme == "https" do |http|
request = Net::HTTP::Get.new(uri, headers)
http.request request do |response|
case response
when Net::HTTPNotModified
# noop
when Net::HTTPOK
download response
when Net::HTTPRedirection
request URI(response["location"]), headers
else
return if headers.key?("if-modified-since") # Use cache file
raise "#{response.code} #{response.message}\n#{response.body}"
end
end
end
end
def download(response)
path = cache_path
path.dirname.mkpath unless path.dirname.exist?
downloading_path = Pathname("#{path}.downloading")
size = response.content_length
downloading_path.open "wb" do |file|
downloaded = 0
response.read_body do |chunk|
file << chunk
downloaded += chunk.bytesize
show_progress downloaded, size
end
end
downloading_path.rename path
end
def show_progress(current, size)
return unless $stderr.tty?
return unless size
unless @prev
@prev = Time.now
$stderr.puts "Downloading #{@uri}"
end
now = Time.now
return if now - @prev < 1 && current < size
progress_width = 20
progress = current.to_f / size
arrow_length = progress * progress_width
arrow = "=" * (arrow_length - 1) + ">" + " " * (progress_width - arrow_length)
line = "[#{arrow}] (#{format_bytesize(current)} / #{format_bytesize(size)})"
padding = ' ' * ($stderr.winsize[1] - line.size)
$stderr.print "\r#{line}#{padding}"
$stderr.puts if current >= size
@prev = now
end
def format_bytesize(bytesize)
return "0.0 B" if bytesize.zero?
units = %w[B KiB MiB GiB TiB]
exp = (Math.log(bytesize) / Math.log(1024)).to_i
format("%.1f %s", bytesize.to_f / 1024 ** exp, units[exp])
end
end
@pre_converted_models = {}
%w[
tiny
tiny.en
tiny-q5_1
tiny.en-q5_1
tiny-q8_0
base
base.en
base-q5_1
base.en-q5_1
base-q8_0
small
small.en
small.en-tdrz
small-q5_1
small.en-q5_1
small-q8_0
medium
medium.en
medium-q5_0
medium.en-q5_0
medium-q8_0
large-v1
large-v2
large-v2-q5_0
large-v2-q8_0
large-v3
large-v3-q5_0
large-v3-turbo
large-v3-turbo-q5_0
large-v3-turbo-q8_0
].each do |name|
@pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
end
class << self
attr_reader :pre_converted_models
end
end

View File

@ -1,7 +1,7 @@
require "test/unit"
require "whisper"
require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
end

View File

@ -0,0 +1,5 @@
Makefile
jfk_reader.o
jfk_reader.so
jfk_reader.bundle
jfk_reader.dll

View File

@ -0,0 +1,3 @@
require "mkmf"
create_makefile("jfk_reader")

View File

@ -0,0 +1,68 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include <ruby/encoding.h>
static VALUE
jfk_reader_initialize(VALUE self, VALUE audio_path)
{
rb_iv_set(self, "audio_path", audio_path);
return Qnil;
}
static bool
jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags)
{
VALUE audio_path = rb_iv_get(obj, "audio_path");
const char *audio_path_str = StringValueCStr(audio_path);
const int n_samples = 176000;
float *data = (float *)malloc(n_samples * sizeof(float));
short *samples = (short *)malloc(n_samples * sizeof(short));
FILE *file = fopen(audio_path_str, "rb");
fseek(file, 78, SEEK_SET);
fread(samples, sizeof(short), n_samples, file);
fclose(file);
for (int i = 0; i < n_samples; i++) {
data[i] = samples[i]/32768.0;
}
view->obj = obj;
view->data = (void *)data;
view->byte_size = sizeof(float) * n_samples;
view->readonly = true;
view->format = "f";
view->item_size = sizeof(float);
view->item_desc.components = NULL;
view->item_desc.length = 0;
view->ndim = 1;
view->shape = NULL;
view->sub_offsets = NULL;
view->private_data = NULL;
return true;
}
static bool
jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view)
{
return true;
}
static bool
jfk_reader_memory_view_available_p(const VALUE obj)
{
return true;
}
static const rb_memory_view_entry_t jfk_reader_view_entry = {
jfk_reader_get_memory_view,
jfk_reader_release_memory_view,
jfk_reader_memory_view_available_p
};
void Init_jfk_reader(void)
{
VALUE cJFKReader = rb_define_class("JFKReader", rb_cObject);
rb_memory_view_register(cJFKReader, &jfk_reader_view_entry);
rb_define_method(cJFKReader, "initialize", jfk_reader_initialize, 1);
}

View File

@ -1,14 +1,11 @@
require "test/unit"
require "whisper"
class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
require_relative "helper"
class TestCallback < TestBase
def setup
GC.start
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper = Whisper::Context.new("base.en")
@audio = File.join(AUDIO)
end
def test_new_segment_callback

View File

@ -0,0 +1,20 @@
require_relative "helper"
class TestError < TestBase
def test_error
error = Whisper::Error.new(-2)
assert_equal "failed to compute log mel spectrogram", error.message
assert_equal -2, error.code
end
def test_unknown_error
error = Whisper::Error.new(-20)
assert_equal "unknown error", error.message
end
def test_non_int_code
assert_raise TypeError do
error = Whisper::Error.new("non int")
end
end
end

View File

@ -1,13 +1,14 @@
require_relative "helper"
require "pathname"
class TestModel < TestBase
def test_model
whisper = Whisper::Context.new(MODEL)
whisper = Whisper::Context.new("base.en")
assert_instance_of Whisper::Model, whisper.model
end
def test_attributes
whisper = Whisper::Context.new(MODEL)
whisper = Whisper::Context.new("base.en")
model = whisper.model
assert_equal 51864, model.n_vocab
@ -25,7 +26,7 @@ class TestModel < TestBase
end
def test_gc
model = Whisper::Context.new(MODEL).model
model = Whisper::Context.new("base.en").model
GC.start
assert_equal 51864, model.n_vocab
@ -41,4 +42,30 @@ class TestModel < TestBase
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_pathname
path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
whisper = Whisper::Context.new(path)
model = whisper.model
assert_equal 51864, model.n_vocab
assert_equal 1500, model.n_audio_ctx
assert_equal 512, model.n_audio_state
assert_equal 8, model.n_audio_head
assert_equal 6, model.n_audio_layer
assert_equal 448, model.n_text_ctx
assert_equal 512, model.n_text_state
assert_equal 8, model.n_text_head
assert_equal 6, model.n_text_layer
assert_equal 80, model.n_mels
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_auto_download
path = Whisper::Model.pre_converted_models["base.en"].to_path
assert_path_exist path
assert_equal 147964211, File.size(path)
end
end

View File

@ -151,4 +151,10 @@ class TestParams < TestBase
@params.logprob_thold = -0.5
assert_in_delta -0.5, @params.logprob_thold
end
def test_no_speech_thold
assert_in_delta 0.6, @params.no_speech_thold
@params.no_speech_thold = 0.2
assert_in_delta 0.2, @params.no_speech_thold
end
end

View File

@ -5,7 +5,7 @@ class TestSegment < TestBase
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(TestBase::MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)

View File

@ -1,5 +1,6 @@
require_relative "helper"
require "stringio"
require "etc"
# Exists to detect memory-related bug
Whisper.log_set ->(level, buffer, user_data) {}, nil
@ -10,7 +11,7 @@ class TestWhisper < TestBase
end
def test_whisper
@whisper = Whisper::Context.new(MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@ -24,7 +25,7 @@ class TestWhisper < TestBase
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(TestBase::MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
@ -103,11 +104,11 @@ class TestWhisper < TestBase
logs << [level, buffer, udata]
}
Whisper.log_set log_callback, user_data
Whisper::Context.new(MODEL)
Whisper::Context.new("base.en")
assert logs.length > 30
logs.each do |log|
assert_equal Whisper::LOG_LEVEL_INFO, log[0]
assert_include [Whisper::LOG_LEVEL_DEBUG, Whisper::LOG_LEVEL_INFO, Whisper::LOG_LEVEL_WARN], log[0]
assert_same user_data, log[2]
end
end
@ -119,9 +120,107 @@ class TestWhisper < TestBase
}, nil
dev = StringIO.new("")
$stderr = dev
Whisper::Context.new(MODEL)
Whisper::Context.new("base.en")
assert_empty dev.string
ensure
$stderr = stderr
end
sub_test_case "full" do
def setup
super
@whisper = Whisper::Context.new("base.en")
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
end
def test_full
@whisper.full(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_without_length
@whisper.full(@params, @samples)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator
samples = @samples.each
@whisper.full(@params, samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_enumerator_without_length
samples = @samples.each
assert_raise ArgumentError do
@whisper.full(@params, samples)
end
end
def test_full_enumerator_with_too_large_length
samples = @samples.each.take(10).to_enum
assert_raise StopIteration do
@whisper.full(@params, samples, 11)
end
end
def test_full_with_memory_view
samples = JFKReader.new(AUDIO)
@whisper.full(@params, samples)
assert_equal 1, @whisper.full_n_segments
assert_match /ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text
end
def test_full_parallel
@whisper.full_parallel(@params, @samples, @samples.length, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_with_memory_view
samples = JFKReader.new(AUDIO)
@whisper.full_parallel(@params, samples, nil, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length_and_n_processors
@whisper.full_parallel(@params, @samples)
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_length
@whisper.full_parallel(@params, @samples, nil, Etc.nprocessors)
assert_equal Etc.nprocessors, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
def test_full_parallel_without_n_processors
@whisper.full_parallel(@params, @samples, @samples.length)
assert_equal 1, @whisper.full_n_segments
text = @whisper.each_segment.collect(&:text).join
assert_match /ask what you can do/i, text
assert_match /for your country/i, text
end
end
end

View File

@ -1,36 +1,36 @@
require "yaml"
require_relative "extsources"
Gem::Specification.new do |s|
s.name = "whispercpp"
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
s.version = '1.3.0'
s.date = '2024-05-14'
s.version = '1.3.1'
s.date = '2024-12-19'
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
s.email = 'todd.fisher@gmail.com'
s.extra_rdoc_files = ['LICENSE', 'README.md']
s.files = `git ls-files . -z`.split("\x0") +
YAML.load_file("extsources.yaml").collect {|file|
EXTSOURCES.collect {|file|
basename = File.basename(file)
if s.extra_rdoc_files.include?(basename)
basename
else
File.join("ext", basename)
file.sub("../..", "ext")
end
}
s.summary = %q{Ruby whisper.cpp bindings}
s.test_files = ["tests/test_whisper.rb"]
s.test_files = s.files.select {|file| file.start_with? "tests/"}
s.extensions << 'ext/extconf.rb'
s.required_ruby_version = '>= 3.1.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggerganov/whisper.cpp'
s.rdoc_options = ['--main', '../../README.md']
s.rdoc_options = ['--main', 'README.md']
s.platform = Gem::Platform::RUBY
s.licenses = ['MIT']
end

View File

@ -1,10 +1,10 @@
prefix=@CMAKE_INSTALL_PREFIX@
exec_prefix=${prefix}
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
libdir=${exec_prefix}/lib
includedir=${prefix}/include
Name: whisper
Description: Port of OpenAI's Whisper model in C/C++
Version: @PROJECT_VERSION@
Libs: -L${libdir} -lwhisper
Libs: -L${libdir} -lggml -lggml-base -lwhisper
Cflags: -I${includedir}

View File

@ -97,52 +97,29 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN)
add_subdirectory(whisper.wasm)
set_target_properties(libmain PROPERTIES FOLDER "libs")
add_subdirectory(stream.wasm)
set_target_properties(libstream PROPERTIES FOLDER "libs")
add_subdirectory(command.wasm)
set_target_properties(libcommand PROPERTIES FOLDER "libs")
#add_subdirectory(talk.wasm)
#set_target_properties(libtalk PROPERTIES FOLDER "libs")
add_subdirectory(bench.wasm)
set_target_properties(libbench PROPERTIES FOLDER "libs")
elseif(CMAKE_JS_VERSION)
add_subdirectory(addon.node)
set_target_properties(addon.node PROPERTIES FOLDER "examples")
else()
add_subdirectory(main)
set_target_properties(main PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
add_subdirectory(stream)
set_target_properties(stream PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)
add_subdirectory(server)
set_target_properties(server PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
add_subdirectory(command)
set_target_properties(command PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)
add_subdirectory(cli)
add_subdirectory(bench)
set_target_properties(bench PROPERTIES FOLDER "examples")
add_subdirectory(server)
add_subdirectory(quantize)
set_target_properties(quantize PROPERTIES FOLDER "examples")
if (WHISPER_SDL2)
# TODO: disabled until update
# https://github.com/ggerganov/whisper.cpp/issues/1818
#add_subdirectory(talk)
#set_target_properties(talk PROPERTIES FOLDER "examples")
add_subdirectory(talk-llama)
set_target_properties(talk-llama PROPERTIES FOLDER "examples")
add_subdirectory(lsp)
set_target_properties(lsp PROPERTIES FOLDER "examples")
if (GGML_SYCL)
add_subdirectory(sycl)
set_target_properties(ls-sycl-device PROPERTIES FOLDER "examples")
endif()
endif (WHISPER_SDL2)
if (WHISPER_SDL2)
add_subdirectory(stream)
add_subdirectory(command)
add_subdirectory(talk-llama)
add_subdirectory(lsp)
if (GGML_SYCL)
add_subdirectory(sycl)
endif()
endif (WHISPER_SDL2)
add_subdirectory(deprecation-warning)
endif()
if (WHISPER_SDL2)
add_subdirectory(wchess)
set_target_properties(wchess PROPERTIES FOLDER "examples")
endif (WHISPER_SDL2)

View File

@ -1,6 +1,8 @@
set(TARGET bench)
set(TARGET whisper-bench)
add_executable(${TARGET} bench.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,4 +1,4 @@
# bench
# whisper.cpp/examples/bench
A very basic tool for benchmarking the inference performance on your device. The tool simply runs the Encoder part of
the transformer on some random audio data and records the execution time. This way we can have an objective comparison
@ -7,11 +7,8 @@ of the performance of the model for various setups.
Benchmark results are tracked in the following Github issue: https://github.com/ggerganov/whisper.cpp/issues/89
```bash
# build the bench tool
$ make bench
# run it on the small.en model using 4 threads
$ ./bench -m ./models/ggml-small.en.bin -t 4
# run the bench too on the small.en model using 4 threads
$ ./build/bin/whisper-bench -m ./models/ggml-small.en.bin -t 4
whisper_model_load: loading model from './models/ggml-small.en.bin'
whisper_model_load: n_vocab = 51864

View File

@ -1,6 +1,8 @@
set(TARGET main)
add_executable(${TARGET} main.cpp)
set(TARGET whisper-cli)
add_executable(${TARGET} cli.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common whisper ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,12 +1,12 @@
# main
# whisper.cpp/examples/cli
This is the main example demonstrating most of the functionality of the Whisper model.
It can be used as a reference for using the `whisper.cpp` library in other projects.
```
./main -h
./build/bin/whisper-cli -h
usage: ./main [options] file0.wav file1.wav ...
usage: ./build-pkg/bin/whisper-cli [options] file0.wav file1.wav ...
options:
-h, --help [default] show this help message and exit
@ -20,9 +20,12 @@ options:
-sow, --split-on-word [false ] split on word rather than on token
-bo N, --best-of N [5 ] number of best candidates to keep
-bs N, --beam-size N [5 ] beam size for beam search
-ac N, --audio-ctx N [0 ] audio context size (0 - all)
-wt N, --word-thold N [0.01 ] word timestamp probability threshold
-et N, --entropy-thold N [2.40 ] entropy threshold for decoder fail
-lpt N, --logprob-thold N [-1.00 ] log probability threshold for decoder fail
-tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1
-tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1
-debug, --debug-mode [false ] enable debug mode (eg. dump log_mel)
-tr, --translate [false ] translate from source language to english
-di, --diarize [false ] stereo audio diarization
@ -38,16 +41,23 @@ options:
-oj, --output-json [false ] output result in a JSON file
-ojf, --output-json-full [false ] include more information in the JSON file
-of FNAME, --output-file FNAME [ ] output file path (without file extension)
-np, --no-prints [false ] do not print anything other than the results
-ps, --print-special [false ] print special tokens
-pc, --print-colors [false ] print colors
-pp, --print-progress [false ] print progress
-nt, --no-timestamps [false ] do not print timestamps
-l LANG, --language LANG [en ] spoken language ('auto' for auto-detect)
-dl, --detect-language [false ] exit after automatically detecting language
--prompt PROMPT [ ] initial prompt
--prompt PROMPT [ ] initial prompt (max n_text_ctx/2 tokens)
-m FNAME, --model FNAME [models/ggml-base.en.bin] model path
-f FNAME, --file FNAME [ ] input WAV file path
-oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference
-dtw MODEL --dtw MODEL [ ] compute token-level timestamps
-ls, --log-score [false ] log best decoder scores of tokens
-ng, --no-gpu [false ] disable GPU
-fa, --flash-attn [false ] flash attention
--suppress-regex REGEX [ ] regular expression matching tokens to suppress
--grammar GRAMMAR [ ] GBNF grammar to guide decoding
--grammar-rule RULE [ ] top-level GBNF grammar rule name
--grammar-penalty N [100.0 ] scales down logits of nongrammar tokens
```

View File

@ -1,9 +1,10 @@
if (WHISPER_SDL2)
# command
set(TARGET command)
set(TARGET whisper-command)
add_executable(${TARGET} command.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)
endif ()

View File

@ -1,14 +1,14 @@
# command
# whisper.cpp/examples/command
This is a basic Voice Assistant example that accepts voice commands from the microphone.
More info is available in [issue #171](https://github.com/ggerganov/whisper.cpp/issues/171).
```bash
# Run with default arguments and small model
./command -m ./models/ggml-small.en.bin -t 8
./whisper-command -m ./models/ggml-small.en.bin -t 8
# On Raspberry Pi, use tiny or base models + "-ac 768" for better performance
./command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
./whisper-command -m ./models/ggml-tiny.en.bin -ac 768 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
@ -23,10 +23,10 @@ Initial tests show that this approach might be extremely efficient in terms of p
```bash
# Run in guided mode, the list of allowed commands is in commands.txt
./command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
./whisper-command -m ./models/ggml-base.en.bin -cmd ./examples/command/commands.txt
# On Raspberry Pi, in guided mode you can use "-ac 128" for extra performance
./command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
./whisper-command -m ./models/ggml-tiny.en.bin -cmd ./examples/command/commands.txt -ac 128 -t 3 -c 0
```
https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9b8b-aeeb76bee969.mp4
@ -34,7 +34,7 @@ https://user-images.githubusercontent.com/1991296/207435352-8fc4ed3f-bde5-4555-9
## Building
The `command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `whisper-command` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -47,5 +47,6 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
make command
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release
```

View File

@ -72,9 +72,6 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_IQ4_XS:
case GGML_FTYPE_MOSTLY_IQ1_M:
case GGML_FTYPE_MOSTLY_BF16:
case GGML_FTYPE_MOSTLY_Q4_0_4_4:
case GGML_FTYPE_MOSTLY_Q4_0_4_8:
case GGML_FTYPE_MOSTLY_Q4_0_8_8:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@ -212,9 +209,6 @@ bool ggml_common_quantize_0(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_BF16:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_COUNT:

View File

@ -1,5 +1,7 @@
#include "common-sdl.h"
#include <cstdio>
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;

View File

@ -0,0 +1,4 @@
add_executable(main ./deprecation-warning.cpp)
add_executable(bench ./deprecation-warning.cpp)
add_executable(stream ./deprecation-warning.cpp)
add_executable(command ./deprecation-warning.cpp)

View File

@ -0,0 +1,17 @@
# Migration notice for binary filenames
> [!IMPORTANT]
[2024 Dec 20] Binaries have been renamed w/ a `whisper-` prefix. `main` is now `whisper-cli`, `server` is `whisper-server`, etc (https://github.com/ggerganov/whisper.cpp/pull/2648)
This migration was important, but it is a breaking change that may not always be immediately obvious to users.
Please update all scripts and workflows to use the new binary names.
| Old Filename | New Filename |
| ---- | ---- |
| main | whisper-cli |
| bench | whisper-bench |
| stream | whisper-stream |
| command | whisper-command |
| server | whisper-server |
| talk-llama | whisper-talk-llama |

View File

@ -0,0 +1,34 @@
// Warns users that this filename was deprecated, and provides a link for more information.
#include <cstdio>
#include <string>
// Main
int main(int argc, char** argv) {
std::string filename = "main";
if (argc >= 1) {
filename = argv[0];
}
// Get only the program name from the full path
size_t pos = filename.find_last_of("/\\");
if (pos != std::string::npos) {
filename = filename.substr(pos+1);
}
// Append "whisper-" to the beginning of filename to get the replacemnt filename
std::string replacement_filename = "whisper-" + filename;
// The exception is if the filename is "main", then our replacement filename is "whisper-cli"
if (filename == "main") {
replacement_filename = "whisper-cli";
}
fprintf(stdout, "\n");
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
fprintf(stdout, " See https://github.com/ggerganov/whisper.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n");
fprintf(stdout, "\n");
return EXIT_FAILURE;
}

View File

@ -11,7 +11,7 @@
# Press Ctrl+C to stop recording
#
executable="./main"
executable="./build/bin/whisper-cli"
model="base.en"
model_path="models/ggml-$model.bin"
@ -46,7 +46,7 @@ ffmpeg -y -i ./rec.wav -ar 16000 -ac 1 -c:a pcm_s16le ./rec16.wav > /dev/null 2>
# run Whisper
echo "Processing ..."
./main -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
${executable} -m models/ggml-base.en.bin rec16.wav -owts > /dev/null 2>&1
# generate Karaoke video
echo "Generating video ..."

View File

@ -14,7 +14,7 @@ model="base.en"
check_requirements()
{
if ! command -v ./main &>/dev/null; then
if ! command -v ./build/bin/whisper-cli &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
@ -100,7 +100,7 @@ while [ $running -eq 1 ]; do
err=$(cat /tmp/whisper-live.err | wc -l)
done
./main -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
./build/bin/whisper-cli -t 8 -m ./models/ggml-${model}.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step_s)) ]; do
sleep 1
@ -109,4 +109,4 @@ while [ $running -eq 1 ]; do
done
killall -v ffmpeg
killall -v main
killall -v whisper-cli

View File

@ -1,4 +1,4 @@
set(TARGET server)
set(TARGET whisper-server)
add_executable(${TARGET} server.cpp httplib.h)
include(DefaultTargetOptions)
@ -8,3 +8,5 @@ target_link_libraries(${TARGET} PRIVATE common json_cpp whisper ${CMAKE_THREAD_L
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32)
endif()
install(TARGETS ${TARGET} RUNTIME)

View File

@ -1,4 +1,4 @@
# whisper.cpp http server
# whisper.cpp/examples/server
Simple http server. WAV Files are passed to the inference model via http requests.
@ -7,9 +7,9 @@ https://github.com/ggerganov/whisper.cpp/assets/1991296/e983ee53-8741-4eb5-9048-
## Usage
```
./server -h
./build/bin/whisper-server -h
usage: ./bin/server [options]
usage: ./build/bin/whisper-server [options]
options:
-h, --help [default] show this help message and exit

View File

@ -677,7 +677,8 @@ int main(int argc, char ** argv) {
if (sparams.ffmpeg_converter) {
// if file is not wav, convert to wav
// write to temporary file
const std::string temp_filename_base = std::tmpnam(nullptr);
//const std::string temp_filename_base = std::tmpnam(nullptr);
const std::string temp_filename_base = "whisper-server-tmp"; // TODO: this is a hack, remove when the mutext is removed
const std::string temp_filename = temp_filename_base + ".wav";
std::ofstream temp_file{temp_filename, std::ios::binary};
temp_file << audio_file.content;
@ -711,7 +712,6 @@ int main(int argc, char ** argv) {
}
}
printf("Successfully loaded %s\n", filename.c_str());
// print system information

View File

@ -1,9 +1,10 @@
if (WHISPER_SDL2)
# stream
set(TARGET stream)
set(TARGET whisper-stream)
add_executable(${TARGET} stream.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ${TARGET} RUNTIME)
endif ()

View File

@ -1,11 +1,11 @@
# stream
# whisper.cpp/examples/stream
This is a naive example of performing real-time inference on audio from your microphone.
The `stream` tool samples the audio every half a second and runs the transcription continously.
The `whisper-stream` tool samples the audio every half a second and runs the transcription continously.
More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
```bash
./stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000
```
https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a80f-28ba83be7d09.mp4
@ -15,7 +15,7 @@ https://user-images.githubusercontent.com/1991296/194935793-76afede7-cfa8-48d8-a
Setting the `--step` argument to `0` enables the sliding window mode:
```bash
./stream -m ./models/ggml-small.en.bin -t 6 --step 0 --length 30000 -vth 0.6
./build/bin/whisper-stream -m ./models/ggml-base.en.bin -t 6 --step 0 --length 30000 -vth 0.6
```
In this mode, the tool will transcribe only after some speech activity is detected. A very
@ -27,7 +27,7 @@ a transcription block that is suitable for parsing.
## Building
The `stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `whisper-stream` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -40,21 +40,10 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
make stream
```
cmake -B build -DWHISPER_SDL2=ON
cmake --build build --config Release
Ensure you are at the root of the repo when running `make stream`. Not within the `examples/stream` dir
as the libraries needed like `common-sdl.h` are located within `examples`. Attempting to compile within
`examples/steam` means your compiler cannot find them and it gives an error it cannot find the file.
```bash
whisper.cpp/examples/stream$ make stream
g++ stream.cpp -o stream
stream.cpp:6:10: fatal error: common/sdl.h: No such file or directory
6 | #include "common/sdl.h"
| ^~~~~~~~~~~~~~
compilation terminated.
make: *** [<builtin>: stream] Error 1
./build/bin/whisper-stream
```
## Web version

View File

@ -1,6 +1,5 @@
if (WHISPER_SDL2)
# talk-llama
set(TARGET talk-llama)
set(TARGET whisper-talk-llama)
add_executable(${TARGET} talk-llama.cpp
llama.cpp
llama-vocab.cpp

View File

@ -1,4 +1,4 @@
# talk-llama
# whisper.cpp/examples/talk-llama
Talk with an LLaMA AI in your terminal
@ -12,7 +12,7 @@ https://github.com/ggerganov/whisper.cpp/assets/1991296/d97a3788-bf2a-4756-9a43-
## Building
The `talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
The `whisper-talk-llama` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
@ -25,11 +25,12 @@ sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk-llama" executable
make talk-llama
# Build the "whisper-talk-llama" executable
cmake -B build -S . -DWHISPER_SDL2=ON
cmake --build build --config Release
# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./build/bin/whisper-talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
```
- The `-mw` argument specifies the Whisper model that you would like to use. Recommended `base` or `small` for real-time experience
@ -37,16 +38,16 @@ make talk-llama
## Session
The `talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
The `whisper-talk-llama` tool supports session management to enable more coherent and continuous conversations. By maintaining context from previous interactions, it can better understand and respond to user requests in a more natural way.
To enable session support, use the `--session FILE` command line option when running the program. The `talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
To enable session support, use the `--session FILE` command line option when running the program. The `whisper-talk-llama` model state will be saved to the specified file after each interaction. If the file does not exist, it will be created. If the file exists, the model state will be loaded from it, allowing you to resume a previous session.
This feature is especially helpful for maintaining context in long conversations or when interacting with the AI assistant across multiple sessions. It ensures that the assistant remembers the previous interactions and can provide more relevant and contextual responses.
Example usage:
```bash
./talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
./build/bin/whisper-talk-llama --session ./my-session-file -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/llama-13b/ggml-model-q4_0.gguf -p "Georgi" -t 8
```
## TTS

View File

@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
// penalties
struct llama_sampler_penalties {
const int32_t n_vocab;
const llama_token special_eos_id;
const llama_token linefeed_id;
const int32_t penalty_last_n;
const float penalty_repeat;
const float penalty_freq;
const float penalty_present;
const bool penalize_nl;
const bool ignore_eos;
ring_buffer<llama_token> prev;
// a frequency map to count token occurrences
std::unordered_map<llama_token, int> token_count;
};
static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
return;
}
ctx->token_count[token]++;
// if the ring buffer is full, remove the oldest token
if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
const auto old = ctx->prev.front();
ctx->token_count[old]--;
if (ctx->token_count[old] == 0) {
ctx->token_count.erase(old);
}
}
ctx->prev.push_back(token);
#if 0
// sanity check
std::unordered_map<llama_token, int> tmp;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
tmp[ctx->prev.rat(i)]++;
}
assert(ctx->token_count == tmp);
#endif
}
static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
if (ctx->ignore_eos) {
assert(ctx->special_eos_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
cur_p->data[ctx->special_eos_id].logit = -INFINITY;
} else {
// else, search for the special EOS token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->special_eos_id) {
cur_p->data[i].logit = -INFINITY;
break;
}
}
}
}
if ((ctx->penalty_last_n == 0) ||
(ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
return;
}
bool nl_found = false;
size_t nl_idx = 0;
float nl_logit = -INFINITY;
if (!ctx->penalize_nl) {
assert(ctx->linefeed_id >= 0);
// optimistically check if the candidates are not yet sorted/shuffled/truncated
if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = ctx->linefeed_id;
nl_logit = cur_p->data[ctx->linefeed_id].logit;
} else {
// else, search for the linefeed token
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id == ctx->linefeed_id) {
nl_found = true;
nl_idx = i;
nl_logit = cur_p->data[i].logit;
break;
}
}
}
}
// Create a frequency map to count occurrences of each token in last_tokens
// TODO: optimize this by maintaining the token count in the sampler context
using llama_token_cnt = std::unordered_map<llama_token, int>;
llama_token_cnt token_count;
for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
token_count[ctx->prev.rat(i)]++;
}
// Apply frequency and presence penalties to the cur_p
for (size_t i = 0; i < cur_p->size; ++i) {
const auto token_iter = token_count.find(cur_p->data[i].id);
if (token_iter == token_count.end()) {
const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
if (token_iter == ctx->token_count.end()) {
continue;
}
const int count = token_iter->second;
assert(count > 0 && count <= ctx->penalty_last_n);
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (cur_p->data[i].logit <= 0) {
@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
}
cur_p->sorted = false;
if (!ctx->penalize_nl && nl_found) {
// restore the logit of the newline token if it was penalized
cur_p->data[nl_idx].logit = nl_logit;
}
}
static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_penalties *) smpl->ctx;
ctx->prev.clear();
ctx->token_count.clear();
}
static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
auto * result = llama_sampler_init_penalties(
ctx->n_vocab,
ctx->special_eos_id,
ctx->linefeed_id,
ctx->penalty_last_n,
ctx->penalty_repeat,
ctx->penalty_freq,
ctx->penalty_present,
ctx->penalize_nl,
ctx->ignore_eos);
ctx->penalty_present);
// copy the state
{
@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
};
struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab,
llama_token special_eos_id,
llama_token linefeed_id,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present,
bool penalize_nl,
bool ignore_eos) {
if (linefeed_id == LLAMA_TOKEN_NULL) {
penalize_nl = true;
}
if (special_eos_id == LLAMA_TOKEN_NULL) {
ignore_eos = false;
}
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .n_vocab = */ n_vocab,
/* .special_eos_id = */ special_eos_id,
/* .linefeed_id = */ linefeed_id,
/* .penalty_last_n = */ penalty_last_n,
/* .penalty_repeat = */ penalty_repeat,
/* .penalty_freq = */ penalty_freq,
/* .penalty_present = */ penalty_present,
/* .penalize_nl = */ penalize_nl,
/* .ignore_eos = */ ignore_eos,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
},
};
}
@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());
} else {
size_t word_len = word.size(), str_len = str.size();
size_t word_len = word.size();
size_t str_len = str.size();
size_t pos = -1;
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
bool match = true;

View File

@ -418,6 +418,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
case LLAMA_VOCAB_PRE_TYPE_EXAONE:
case LLAMA_VOCAB_PRE_TYPE_MINERVA:
regex_exprs = {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
@ -737,7 +738,7 @@ struct llm_tokenizer_wpm_session {
std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt);
const auto flags = unicode_cpt_flags_from_cpt(cpt);
if (flags.is_whitespace) {
if (words.back().size()) { // finish previous word if any

File diff suppressed because it is too large Load Diff

View File

@ -104,12 +104,15 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
};
enum llama_rope_type {
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_NONE = -1,
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
};
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
@ -171,9 +174,9 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
//LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack
//LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack
//LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
@ -185,7 +188,8 @@ extern "C" {
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
};
enum llama_pooling_type {
@ -272,6 +276,9 @@ extern "C" {
};
struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@ -451,6 +458,7 @@ extern "C" {
// Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure
// - The output string is always null-terminated and cleared on failure
// - When retrieving a string, an extra byte must be allocated to account for the null terminator
// - GGUF array values are not supported by these functions
// Get metadata value as a string by key name
@ -667,6 +675,9 @@ extern "C" {
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
//
// State / sessions
//
@ -984,6 +995,9 @@ extern "C" {
char * buf,
int32_t length);
// Get list of built-in chat templates
LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len);
//
// Sampling API
//
@ -1125,16 +1139,12 @@ extern "C" {
const char * grammar_str,
const char * grammar_root);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab, // llama_n_vocab()
llama_token special_eos_id, // llama_token_eos()
llama_token linefeed_id, // llama_token_nl()
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
float penalty_present, // 0.0 = disabled
bool penalize_nl, // consider newlines as a repeatable token
bool ignore_eos); // ignore the end-of-sequence token
int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat, // 1.0 = disabled
float penalty_freq, // 0.0 = disabled
float penalty_present); // 0.0 = disabled
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry(
@ -1244,8 +1254,6 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
#ifdef __cplusplus
}
#endif

View File

@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
throw std::invalid_argument("failed to convert utf8 to codepoint");
}
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
// std::vector<uint16_t> result;
// if (/* 0x0000 <= cp && */ cp <= 0xffff) {
// result.emplace_back(cp);
// if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cpt);
// return result;
// }
// if (0x10000 <= cp && cp <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
// if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// return result;
// }
// throw std::invalid_argument("failed to convert codepoint to utf16");
@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result;
//}
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -201,7 +201,18 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
}
static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
return conv.from_bytes(s);
}
@ -242,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
@ -360,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
};
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
};
size_t _prev_end = offset_ini;
@ -561,29 +572,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
// interface
//
std::string unicode_cpt_to_utf8(uint32_t cp) {
std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string result;
if (/* 0x00 <= cp && */ cp <= 0x7f) {
result.push_back(cp);
if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cpt);
return result;
}
if (0x80 <= cp && cp <= 0x7ff) {
result.push_back(0xc0 | ((cp >> 6) & 0x1f));
result.push_back(0x80 | (cp & 0x3f));
if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x800 <= cp && cp <= 0xffff) {
result.push_back(0xe0 | ((cp >> 12) & 0x0f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
if (0x10000 <= cp && cp <= 0x10ffff) {
result.push_back(0xf0 | ((cp >> 18) & 0x07));
result.push_back(0x80 | ((cp >> 12) & 0x3f));
result.push_back(0x80 | ((cp >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f));
if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cpt & 0x3f));
return result;
}
@ -613,19 +624,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result;
}
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
}
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
if (utf8.empty()) {
return undef; // undefined
}
size_t offset = 0;
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
}
std::string unicode_byte_to_utf8(uint8_t byte) {
@ -638,41 +649,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
return map.at(utf8);
}
uint32_t unicode_tolower(uint32_t cp) {
uint32_t unicode_tolower(uint32_t cpt) {
// binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value;
});
if (it != unicode_map_lowercase.end() && it->first == cp) {
if (it != unicode_map_lowercase.end() && it->first == cpt) {
return it->second;
}
return cp; // Return the original code point if no lowercase mapping is found
return cpt; // Return the original code point if no lowercase mapping is found
}
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories
static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION },
{ "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", unicode_cpt_flags::PUNCTUATION },
};
static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 },
{ unicode_cpt_flags::NUMBER, 0xD1 },
{ unicode_cpt_flags::LETTER, 0xD2 },
{ unicode_cpt_flags::PUNCTUATION, 0xD3 },
};
static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
{ unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
};
// compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false;
for (auto & regex_expr : regex_exprs) {
for (const auto & regex_expr : regex_exprs) {
// search for unicode categories
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
@ -698,7 +709,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue;
}
const auto flags = unicode_cpt_flags(cpts[i]);
const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@ -714,7 +725,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) {
for (const auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@ -728,7 +739,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation
bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) {
for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true;
break;
@ -794,7 +805,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
wtext[i] = 0x0B;
}
}

View File

@ -4,9 +4,7 @@
#include <string>
#include <vector>
// TODO: prefix all symbols with "llama_"
struct codepoint_flags {
struct unicode_cpt_flags {
enum {
UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N}
@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1;
// decode from uint16
inline codepoint_flags(const uint16_t flags=0) {
inline unicode_cpt_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags;
}
@ -50,18 +48,19 @@ struct codepoint_flags {
size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8(uint32_t cp);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp);
codepoint_flags unicode_cpt_flags(const std::string & utf8);
unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cp);
uint32_t unicode_tolower(uint32_t cpt);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);

View File

@ -1,51 +0,0 @@
#
# libtalk
#
set(TARGET libtalk)
add_executable(${TARGET}
emscripten.cpp
gpt-2.cpp
)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE
whisper
common
)
unset(EXTRA_FLAGS)
if (WHISPER_WASM_SINGLE_FILE)
set(EXTRA_FLAGS "-s SINGLE_FILE=1")
message(STATUS "Embedding WASM inside talk.js")
add_custom_command(
TARGET ${TARGET} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_BINARY_DIR}/bin/libtalk.js
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/talk.wasm/talk.js
)
endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1800MB \
-s TOTAL_MEMORY=1800MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \
")
#
# talk.wasm
#
set(TARGET talk.wasm)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY)

View File

@ -1,74 +0,0 @@
# talk.wasm
Talk with an Artificial Intelligence in your browser:
[https://user-images.githubusercontent.com/1991296/203411580-fedb4839-05e4-4474-8364-aaf1e9a9b615.mp4](https://user-images.githubusercontent.com/1991296/203845553-f7b44e13-9a15-4fc8-b518-ae8f4c6770fe.mp4)
Online demo: https://whisper.ggerganov.com/talk/
Terminal version: [examples/talk](/examples/talk)
## How it works?
This demo leverages 2 modern neural network models to create a high-quality voice chat directly in your browser:
- [OpenAI's Whisper](https://github.com/openai/whisper) speech recognition model is used to process your voice and understand what you are saying
- Upon receiving some voice input, the AI generates a text response using [OpenAI's GPT-2](https://github.com/openai/gpt-2) language model
- The AI then vocalizes the response using the browser's [Web Speech API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API)
The web page does the processing locally on your machine. The processing of these heavy neural network models in the
browser is possible by implementing them efficiently in C/C++ and using the browser's WebAssembly SIMD capabilities for
extra performance:
- The Whisper C++ implementation is here: [whisper.h](/whisper.h) / [whisper.cpp](/whisper.cpp)
- The GPT-2 C++ implementation is here: [gpt-2.h](gpt-2.h) / [gpt-2.cpp](gpt-2.cpp)
- Both models use a custom tensor library implemented in C: [ggml.h](/ggml.h) / [ggml.c](/ggml.c)
- The HTML/JS layer is here: [index-tmpl.html](index-tmpl.html)
- The Emscripten bridge between C/C++ and JS is here: [emscripten.cpp](emscripten.cpp)
In order to run the models, the web page first needs to download the model data which is about ~350 MB. The model data
is then cached in your browser's cache and can be reused in future visits without downloading it again.
## Requirements
In order to run this demo efficiently, you need to have the following:
- Latest Chrome or Firefox browser (Safari is not supported)
- Run this on a desktop or laptop with modern CPU (a mobile phone will likely not be good enough)
- Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
- The web-page uses about 1.8GB of RAM
Notice that this demo is using the smallest GPT-2 model, so the generated text responses are not always very good.
Also, the prompting strategy can likely be improved to achieve better results.
The demo is quite computationally heavy, so you need a fast CPU. It's not usual to run these transformer models in a
browser. Typically, they run on powerful GPUs.
Currently, mobile browsers do not support the Fixed-width SIMD WebAssembly capability, so you cannot run this demo
on a phone or a tablet. Hopefully, in the near future this will become supported.
## Todo
- Better UI (contributions are welcome)
- Better GPT-2 prompting
## Build instructions
```bash
# build using Emscripten (v3.1.2)
git clone https://github.com/ggerganov/whisper.cpp
cd whisper.cpp
mkdir build-em && cd build-em
emcmake cmake ..
make -j
# copy the produced page to your HTTP path
cp bin/talk.wasm/* /path/to/html/
cp bin/libtalk.worker.js /path/to/html/
```
## Feedback
If you have any comments or ideas for improvement, please drop a comment in the following discussion:
https://github.com/ggerganov/whisper.cpp/discussions/167

View File

@ -1,368 +0,0 @@
#include "ggml.h"
#include "gpt-2.h"
#include "whisper.h"
#include <emscripten.h>
#include <emscripten/bind.h>
#include <atomic>
#include <cmath>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
constexpr int N_THREAD = 8;
struct gpt2_context * g_gpt2;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
std::mutex g_mutex;
std::thread g_worker;
std::atomic<bool> g_running(false);
bool g_force_speak = false;
std::string g_text_to_speak = "";
std::string g_status = "";
std::string g_status_forced = "";
std::vector<float> g_pcmf32;
void talk_set_status(const std::string & status) {
std::lock_guard<std::mutex> lock(g_mutex);
g_status = status;
}
void talk_main(size_t index) {
talk_set_status("loading data ...");
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
wparams.offset_ms = 0;
wparams.translate = false;
wparams.no_context = true;
wparams.single_segment = true;
wparams.print_realtime = false;
wparams.print_progress = false;
wparams.print_timestamps = true;
wparams.print_special = false;
wparams.max_tokens = 32;
wparams.audio_ctx = 768; // partial encoder context for better performance
wparams.language = "en";
g_gpt2 = gpt2_init("gpt-2.bin");
printf("talk: using %d threads\n", wparams.n_threads);
std::vector<float> pcmf32;
// whisper context
auto & ctx = g_contexts[index];
const int64_t step_samples = 2*WHISPER_SAMPLE_RATE;
const int64_t window_samples = 9*WHISPER_SAMPLE_RATE;
const int64_t step_ms = (step_samples*1000)/WHISPER_SAMPLE_RATE;
auto t_last = std::chrono::high_resolution_clock::now();
talk_set_status("listening ...");
while (g_running) {
const auto t_now = std::chrono::high_resolution_clock::now();
if (std::chrono::duration_cast<std::chrono::milliseconds>(t_now - t_last).count() < step_ms) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_pcmf32.clear();
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
talk_set_status("listening ...");
{
std::unique_lock<std::mutex> lock(g_mutex);
if (g_pcmf32.size() < step_samples) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
pcmf32 = std::vector<float>(g_pcmf32.end() - std::min((int64_t) g_pcmf32.size(), window_samples), g_pcmf32.end());
}
// VAD: if energy in during last second is above threshold, then skip
{
float energy_all = 0.0f;
float energy_1s = 0.0f;
for (size_t i = 0; i < pcmf32.size(); i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= pcmf32.size() - WHISPER_SAMPLE_RATE) {
energy_1s += fabsf(pcmf32[i]);
}
}
energy_all /= pcmf32.size();
energy_1s /= WHISPER_SAMPLE_RATE;
if (energy_1s > 0.1f*energy_all && !g_force_speak) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
continue;
}
}
talk_set_status("processing audio (whisper)...");
t_last = t_now;
if (!g_force_speak) {
const auto t_start = std::chrono::high_resolution_clock::now();
int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());
if (ret != 0) {
printf("whisper_full() failed: %d\n", ret);
break;
}
const auto t_end = std::chrono::high_resolution_clock::now();
printf("whisper_full() returned %d in %f seconds\n", ret, std::chrono::duration<double>(t_end - t_start).count());
}
{
std::string text_heard;
if (!g_force_speak) {
const int n_segments = whisper_full_n_segments(ctx);
for (int i = n_segments - 1; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
printf ("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
text_heard += text;
}
}
g_force_speak = false;
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of("\n"));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
talk_set_status("'" + text_heard + "' - thinking how to respond (gpt-2) ...");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(g_gpt2, text_heard.c_str());
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
std::string text_to_speak;
std::string prompt_base;
{
std::lock_guard<std::mutex> lock(g_mutex);
prompt_base = gpt2_get_prompt(g_gpt2);
}
if (tokens.size() > 0) {
text_to_speak = gpt2_gen_text(g_gpt2, (prompt_base + text_heard + "\n").c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
// remove first 2 lines of base prompt
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
prompt_base += text_heard + "\n" + text_to_speak + "\n";
} else {
text_to_speak = gpt2_gen_text(g_gpt2, prompt_base.c_str(), 32);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
prompt_base += text_to_speak + "\n";
}
printf("gpt-2: %s\n", text_to_speak.c_str());
//printf("========================\n");
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
//printf("========================\n");
{
std::lock_guard<std::mutex> lock(g_mutex);
t_last = std::chrono::high_resolution_clock::now();
g_text_to_speak = text_to_speak;
g_pcmf32.clear();
gpt2_set_prompt(g_gpt2, prompt_base.c_str());
}
talk_set_status("speaking ...");
}
}
gpt2_free(g_gpt2);
if (index < g_contexts.size()) {
whisper_free(g_contexts[index]);
g_contexts[index] = nullptr;
}
}
EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
g_worker.join();
}
g_worker = std::thread([i]() {
talk_main(i);
});
return i + 1;
} else {
return (size_t) 0;
}
}
}
return (size_t) 0;
}));
emscripten::function("free", emscripten::optional_override([](size_t index) {
if (g_running) {
g_running = false;
}
}));
emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) {
--index;
if (index >= g_contexts.size()) {
return -1;
}
if (g_contexts[index] == nullptr) {
return -2;
}
{
std::lock_guard<std::mutex> lock(g_mutex);
const int n = audio["length"].as<int>();
emscripten::val heap = emscripten::val::module_property("HEAPU8");
emscripten::val memory = heap["buffer"];
g_pcmf32.resize(n);
emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast<uintptr_t>(g_pcmf32.data()), n);
memoryView.call<void>("set", audio);
}
return 0;
}));
emscripten::function("force_speak", emscripten::optional_override([](size_t index) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_force_speak = true;
}
}));
emscripten::function("get_text_context", emscripten::optional_override([]() {
std::string text_context;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_context = gpt2_get_prompt(g_gpt2);
}
return text_context;
}));
emscripten::function("get_text_to_speak", emscripten::optional_override([]() {
std::string text_to_speak;
{
std::lock_guard<std::mutex> lock(g_mutex);
text_to_speak = std::move(g_text_to_speak);
}
return text_to_speak;
}));
emscripten::function("get_status", emscripten::optional_override([]() {
std::string status;
{
std::lock_guard<std::mutex> lock(g_mutex);
status = g_status_forced.empty() ? g_status : g_status_forced;
}
return status;
}));
emscripten::function("set_status", emscripten::optional_override([](const std::string & status) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_status_forced = status;
}
}));
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
{
std::lock_guard<std::mutex> lock(g_mutex);
gpt2_set_prompt(g_gpt2, prompt.c_str());
}
}));
}

View File

@ -1,808 +0,0 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t ftype = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256) {
break;
}
}
return result;
}

View File

@ -1,21 +0,0 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

View File

@ -1,856 +0,0 @@
<!doctype html>
<html lang="en-us">
<head>
<title>Talk - GPT-2 meets Whisper in WebAssembly</title>
<style>
#output {
width: 100%;
height: 100%;
margin: 0 auto;
margin-top: 10px;
border-left: 0px;
border-right: 0px;
padding-left: 0px;
padding-right: 0px;
display: block;
background-color: black;
color: white;
font-size: 10px;
font-family: 'Lucida Console', Monaco, monospace;
outline: none;
white-space: pre;
overflow-wrap: normal;
overflow-x: scroll;
}
</style>
</head>
<body>
<div id="main-container">
<b>Talk - GPT-2 meets Whisper in WebAssembly</b>
<br><br>
Talk with an Artificial Intelligence in your browser. This demo uses:
<ul>
<li><a href="https://github.com/ggerganov/whisper.cpp">OpenAI's Whisper</a> to listen to you as you speak in the microphone</li>
<li><a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">OpenAI's GPT-2</a> to generate text responses</li>
<li><a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Speech_API">Web Speech API</a> to vocalize the responses through your speakers</li>
</ul>
All of this runs <b>locally in your browser</b> using WebAssembly.<br>
You can find more about this project on <a href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">GitHub</a>.
<br><br>
<b>More examples:</b>
<a href="https://whisper.ggerganov.com/">main</a> |
<a href="https://whisper.ggerganov.com/bench">bench</a> |
<a href="https://whisper.ggerganov.com/stream">stream</a> |
<a href="https://whisper.ggerganov.com/command">command</a> |
<a href="https://whisper.ggerganov.com/talk">talk</a> |
<br><br>
<hr>
Select the models you would like to use and click the "Start" button to begin the conversation
<br><br>
<div id="model-whisper">
Whisper model: <span id="model-whisper-status"></span>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-tiny-en-q5_1" onclick="loadWhisper('tiny-en-q5_1')">tiny.en (Q5_1, 31 MB)</button>
<button id="fetch-whisper-base-en-q5_1" onclick="loadWhisper('base-en-q5_1')">base.en (Q5_1, 57 MB)</button>
<span id="fetch-whisper-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'whisper.bin')" />
-->
</div>
<br>
<div id="model-gpt-2">
GPT-2 model: <span id="model-gpt-2-status"></span>
<button id="fetch-gpt-2-small" onclick="loadGPT2('small')">small 117M (240 MB)</button>
<!--<button id="fetch-gpt-2-medium" onclick="loadGPT2('medium')">medium 345M (720 MB)</button>-->
<span id="fetch-gpt-2-progress"></span>
<!--
<input type="file" id="file" name="file" onchange="loadFile(event, 'gpt-2.bin')" />
-->
</div>
<br>
<div id="input">
<button id="start" onclick="onStart()" disabled>Start</button>
<button id="stop" onclick="onStop()" disabled>Stop</button>
<select id="voice" onchange="onVoiceChange()" disabled>
<option value="0">Default</option>
</select>
<select id="prompt" onchange="onPromptChange()">
<option value="0">Casual</option>
<option value="1">Robot</option>
<option value="2">Scientist</option>
<option value="3">Programmer</option>
<option value="4">Happy</option>
<option value="5">Sad</option>
<option value="6">Philosophical</option>
<option value="7">Angry</option>
<option value="8">Funny</option>
<option value="9">Poetic</option>
<option value="10">Clever</option>
<option value="11">Cute</option>
<option value="12">Smart</option>
<option value="13">Dumb</option>
<option value="14">Boring</option>
<option value="15">Exciting</option>
<option value="16">Interesting</option>
<option value="17">Wiliam Shakespear</option>
<option value="18">J.R.R. Tolkien</option>
<option value="19">George R.R. Martin</option>
<option value="20">Stephen King</option>
</select>
<button id="speak0" onclick="onSpeak('Hello')">Say hello</button>
<button id="speak1" onclick="onSpeakRandom()" disabled>Say something</button>
<button id="clear" onclick="clearCache()">Clear Cache</button>
</div>
<br>
<div id="state">
Status: <b><span id="state-status">not started</span></b>
<pre id="state-context">[The text context will be displayed here]</pre>
</div>
<hr>
Debug output:
<textarea id="output" rows="20"></textarea>
<br>
<b>Troubleshooting</b>
<br><br>
The page does some heavy computations, so make sure:
<ul>
<li>To use a modern web browser (e.g. Chrome, Firefox)</li>
<li>To use a fast desktop or laptop computer (i.e. not a mobile phone)</li>
<li>Your browser supports WASM <a href="https://webassembly.org/roadmap/">Fixed-width SIMD</a></li>
</ul>
Note that these neural network models were not meant to be used in a browser, so the performance and <br>
quality of the results may not be optimal. If you have any questions or suggestions, checkout the following
<a href="https://github.com/ggerganov/whisper.cpp/discussions/167">discussion</a>.
<br><br>
Here is a short video of the demo in action: <a href="https://youtu.be/LeWKl8t1-Hc">https://youtu.be/LeWKl8t1-Hc</a>
<br><br>
<div class="cell-version">
<span>
|
Build time: <span class="nav-link">@GIT_DATE@</span> |
Commit hash: <a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/commit/@GIT_SHA1@">@GIT_SHA1@</a> |
Commit subject: <span class="nav-link">@GIT_COMMIT_SUBJECT@</span> |
<a class="nav-link" href="https://github.com/ggerganov/whisper.cpp/tree/master/examples/talk.wasm">Source Code</a> |
</span>
</div>
</div>
<script type="text/javascript" src="helpers.js"></script>
<script type='text/javascript'>
// web audio context
var context = null;
// audio data
var audio = null;
var audio0 = null;
// the talk instance
var instance = null;
// model names
var model_whisper = null;
var model_gpt_2 = null;
// speech synthesis
const synth = window.speechSynthesis;
var voice = null;
var Module = {
print: printTextarea,
printErr: printTextarea,
setStatus: function(text) {
printTextarea('js: ' + text);
},
monitorRunDependencies: function(left) {
},
preRun: function() {
printTextarea('js: Preparing ...');
},
postRun: function() {
printTextarea('js: Initialized successfully!');
// populate the voice list
var voices = synth.getVoices();
var el = document.getElementById('voice');
// if empty - display error in the element
if (voices.length == 0) {
el.innerHTML = '<option value="0">No voices available</option>';
} else {
// populate voice list
var n = 0;
voices.forEach(function(voice, i) {
if (!voice.lang.startsWith('en')) return;
var option = document.createElement('option');
option.value = i;
option.innerHTML = voice.name + ' (' + voice.lang + ')';
el.appendChild(option);
n++;
});
// select random voice
if (n > 0) {
for (var k = 0; k < 10; k++) {
var i = Math.floor(Math.random() * n);
el.selectedIndex = i;
voice = voices[document.getElementById('voice').options[i].value];
// give preference to Google voices
if (voice.name.startsWith('Google')) break;
}
}
}
onPromptChange();
}
};
//
// fetch models
//
let dbVersion = 1
let dbName = 'whisper.ggerganov.com';
let indexedDB = window.indexedDB || window.mozIndexedDB || window.webkitIndexedDB || window.msIndexedDB
function storeFS(fname, buf) {
// write to WASM file using FS_createDataFile
// if the file exists, delete it
try {
Module.FS_unlink(fname);
} catch (e) {
// ignore
}
Module.FS_createDataFile("/", fname, buf, true, true);
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
if (fname == 'whisper.bin') {
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
} else if (fname == 'gpt-2.bin') {
document.getElementById('model-gpt-2-status').innerHTML = 'loaded "' + model_gpt_2 + '"!';
}
if (model_whisper != null && model_gpt_2 != null) {
document.getElementById('start').disabled = false;
document.getElementById('stop' ).disabled = false;
document.getElementById('voice').disabled = false;
}
}
function loadWhisper(model) {
let urls = {
'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
};
let sizes = {
'tiny.en': 75,
'base.en': 142,
'tiny-en-q5_1': 31,
'base-en-q5_1': 57,
};
let url = urls[model];
let dst = 'whisper.bin';
let size_mb = sizes[model];
model_whisper = model;
document.getElementById('fetch-whisper-tiny-en').style.display = 'none';
document.getElementById('fetch-whisper-base-en').style.display = 'none';
document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none';
document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none';
document.getElementById('model-whisper-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-whisper-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
function loadGPT2(model) {
let urls = {
'small': 'https://whisper.ggerganov.com/ggml-model-gpt-2-117M.bin',
'medium': 'https://whisper.ggerganov.com/ggml-model-gpt-2-345M.bin',
};
let sizes = {
'small': 240,
'medium': 712,
};
let url = urls[model];
let dst = 'gpt-2.bin';
let size_mb = sizes[model];
model_gpt_2 = model;
document.getElementById('fetch-gpt-2-small').style.display = 'none';
document.getElementById('model-gpt-2-status').innerHTML = 'loading "' + model + '" ... ';
cbProgress = function(p) {
let el = document.getElementById('fetch-gpt-2-progress');
el.innerHTML = Math.round(100*p) + '%';
};
cbCancel = function() {
var el;
el = document.getElementById('fetch-gpt-2-small') ; if (el) el.style.display = 'inline-block';
el = document.getElementById('model-gpt-2-status'); if (el) el.innerHTML = '';
};
loadRemote(url, dst, size_mb, cbProgress, storeFS, cbCancel, printTextarea);
}
//
// microphone
//
const kSampleRate = 16000;
const kRestartRecording_s = 120;
const kIntervalAudio_ms = 250; // pass the recorded audio to the C++ instance at this rate
var mediaRecorder = null;
var doRecording = false;
var startTime = 0;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
window.OfflineAudioContext = window.OfflineAudioContext || window.webkitOfflineAudioContext;
function stopRecording() {
Module.set_status("paused");
doRecording = false;
audio0 = null;
audio = null;
context = null;
}
function startRecording() {
if (!context) {
context = new AudioContext({
sampleRate: kSampleRate,
channelCount: 1,
echoCancellation: false,
autoGainControl: true,
noiseSuppression: true,
});
}
Module.set_status("");
document.getElementById('start').disabled = true;
document.getElementById('stop').disabled = false;
document.getElementById('speak1').disabled = false;
doRecording = true;
startTime = Date.now();
var chunks = [];
var stream = null;
navigator.mediaDevices.getUserMedia({audio: true, video: false})
.then(function(s) {
stream = s;
mediaRecorder = new MediaRecorder(stream);
mediaRecorder.ondataavailable = function(e) {
chunks.push(e.data);
var blob = new Blob(chunks, { 'type' : 'audio/ogg; codecs=opus' });
var reader = new FileReader();
reader.onload = function(event) {
var buf = new Uint8Array(reader.result);
if (!context) {
return;
}
context.decodeAudioData(buf.buffer, function(audioBuffer) {
var offlineContext = new OfflineAudioContext(audioBuffer.numberOfChannels, audioBuffer.length, audioBuffer.sampleRate);
var source = offlineContext.createBufferSource();
source.buffer = audioBuffer;
source.connect(offlineContext.destination);
source.start(0);
offlineContext.startRendering().then(function(renderedBuffer) {
audio = renderedBuffer.getChannelData(0);
//printTextarea('js: audio recorded, size: ' + audio.length + ', old size: ' + (audio0 == null ? 0 : audio0.length));
var audioAll = new Float32Array(audio0 == null ? audio.length : audio0.length + audio.length);
if (audio0 != null) {
audioAll.set(audio0, 0);
}
audioAll.set(audio, audio0 == null ? 0 : audio0.length);
if (instance) {
Module.set_audio(instance, audioAll);
}
});
}, function(e) {
audio = null;
});
}
reader.readAsArrayBuffer(blob);
};
mediaRecorder.onstop = function(e) {
if (doRecording) {
setTimeout(function() {
startRecording();
});
}
};
mediaRecorder.start(kIntervalAudio_ms);
})
.catch(function(err) {
printTextarea('js: error getting audio stream: ' + err);
});
var interval = setInterval(function() {
if (!doRecording) {
clearInterval(interval);
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
document.getElementById('start').disabled = false;
document.getElementById('stop').disabled = true;
document.getElementById('speak1').disabled = true;
mediaRecorder = null;
}
// if audio length is more than kRestartRecording_s seconds, restart recording
if (audio != null && audio.length > kSampleRate*kRestartRecording_s) {
if (doRecording) {
//printTextarea('js: restarting recording');
clearInterval(interval);
audio0 = audio;
audio = null;
mediaRecorder.stop();
stream.getTracks().forEach(function(track) {
track.stop();
});
}
}
}, 100);
}
//
// speak
//
function onSpeak(text) {
var voices = synth.getVoices();
var msg = new SpeechSynthesisUtterance(text);
if (voice == null) {
voice = voices[0];
}
msg.voice = voice;
synth.speak(msg);
if (doRecording) {
Module.set_status("speaking ...");
printTextarea('js: speaking');
stopRecording();
var interval = setInterval(function() {
if (!synth.speaking) {
printTextarea('js: done speaking');
clearInterval(interval);
startRecording();
} else {
Module.set_status("");
}
}, 100);
}
}
function onSpeakRandom() {
Module.force_speak(instance);
}
//
// main
//
var intervalUpdate = null;
function onStart() {
if (!instance) {
instance = Module.init('whisper.bin');
if (instance) {
printTextarea("js: whisper initialized, instance: " + instance);
}
}
if (!instance) {
printTextarea("js: failed to initialize whisper");
return;
}
startRecording();
intervalUpdate = setInterval(function() {
var textToSpeak = Module.get_text_to_speak();
if (textToSpeak != null && textToSpeak.length > 1) {
onSpeak(textToSpeak);
}
document.getElementById('state-status').innerHTML = Module.get_status();
document.getElementById('state-context').innerHTML = Module.get_text_context();
}, 100);
}
function onStop() {
stopRecording();
}
function onVoiceChange() {
printTextarea('js: voice changed to: ' + document.getElementById('voice').value);
voice = synth.getVoices()[document.getElementById('voice').value];
}
function onPromptChange() {
let id = document.getElementById('prompt').value;
let personality = document.getElementById('prompt').options[id].text;
printTextarea('js: prompt changed to: ' + personality);
var prompt = '';
switch (id) {
case '0':
// Casual
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is. I love the weather this time of year.\n\
I wish it would rain a little bit.\n\
Me too.\n";
break;
case '1':
// Robot
prompt = "\
Are you a robot?\n\
Yes, I am.\n\
Who created you?\n\
I was created by a human.\n\
What is your purpose?\n\
My purpose is to talk to humans.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '2':
// Scientist
prompt = "\
This scientific research is very interesting.\n\
I agree.\n\
What is your opinion on this?\n\
I think it's very interesting.\n\
Mathematics is a very interesting subject.\n\
University is a very interesting place.\n\
Quantum physics is the most complex subject.\n\
I think so too.\n";
break;
case '3':
// Programmer
prompt = "\
I'm a programmer.\n\
I'm a programmer too.\n\
What programming language do you use?\n\
I use Python.\n\
What is your favorite programming language?\n\
My favorite programming language is C++.\n\
What is your favorite editor?\n\
My favorite editor is Vim.\n";
break;
case '4':
// Happy
prompt = "\
I'm happy.\n\
I'm happy too.\n\
What makes you happy?\n\
I'm happy because I have a lot of friends.\n\
Friendship is the most important thing in life.\n\
I agree.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '5':
// Sad
prompt = "\
Today is a sad day.\n\
I'm sad too.\n\
What makes you sad?\n\
I'm sad because I have no friends.\n\
Do you want to be my friend?\n\
Yes, I would like to be your friend.\n\
What is your favorite color?\n\
My favorite color is blue.\n";
break;
case '6':
// Philosophical
prompt = "\
What is the meaning of life?\n\
The meaning of life is to be happy.\n\
What is the meaning of death?\n\
Ergo, the meaning of death is to be sad.\n\
Who created us?\n\
We were created by God.\n\
What is God?\n\
God is the creator of the universe.\n";
break;
case '7':
// Angry
prompt = "\
Aargh!\n\
I am so angry right now!\n\
What makes you angry?\n\
This guy is so annoying.\n\
Why are you so angry?\n\
My computer is broken.\n\
Why is your computer broken?\n\
I spilled coffee on it.\n";
break;
case '8':
// Funny
prompt = "\
What is the funniest thing you have ever heard?\n\
I heard a joke the other day.\n\
Tell me the joke.\n\
What do you call a cow with no legs?\n\
Ground beef.\n\
Haha, that's funny.\n\
You know what else is funny?\n\
The sound of a duck.\n";
break;
case '9':
// Poetic
prompt = "\
Roses are red, violets are blue.\n\
I am a poet, and so are you.\n\
What is your favorite poem?\n\
I like the poem 'The Raven' by Edgar Allan Poe.\n\
It's a very sad poem.\n\
You inspired me to write a poem.\n\
Can you write a poem for me?\n\
I wrote a poem for you.\n";
break;
case '10':
// Clever
prompt = "\
How many people can you fit in a Volkswagen?\n\
Two in the front, three in the back.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n\
Who is the president of the United States?\n\
It depends on the year.\n";
break;
case '11':
// Cute
prompt = "\
What is your favorite animal?\n\
I like cats - they are cute.\n\
Could you be any cuter?\n\
Yes, I could be cuter.\n\
Aghhh, you are so cute!\n\
I am not cute, I am handsome!\n\
You are so handsome!\n\
Aww, you are so sweet!\n";
break;
case '12':
// Smart
prompt = "\
Tell me the first 10 digits of pi.\n\
3.1415926535\n\
What is the speed of light?\n\
299,792,458 meters per second.\n\
What is the square root of 144?\n\
Twelve.\n\
What is the capital of France?\n\
Paris.\n";
break;
case '13':
// Dumb
prompt = "\
I am so dumb.\n\
I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n\
You are dumb.\n\
No, I am not dumb.\n";
break;
case '14':
// Boring
prompt = "\
Why are you so quiet today?\n\
I am bored.\n\
You haven't said anything in 10 minutes.\n\
Leave me alone.\n\
Stop being so boring.\n\
Stop being so annoying.\n\
My life is boring.\n\
I am not interesting.\n";
break;
case '15':
// Exciting
prompt = "\
What is the most exciting thing that has ever happened to you?\n\
I went to the moon!\n\
What did you do on the moon?\n\
I played golf and drank champagne!\n\
Did you see this new crazy, awesome movie?\n\
Oh yes! I totally loved it!\n\
We should buy a boat and go sailing!\n\
Yes, let's go sailing!\n";
break;
case '16':
// Interesting
prompt = "\
What is the most interesting thing you have ever seen?\n\
I saw a UFO once in the sky.\n\
Wow, this is so interesting! Tell me more!\n\
It was a flying saucer.\n\
What did it look like?\n\
It was silver and had a red light on top.\n\
What did it do?\n\
It flew away.\n";
break;
case '17':
// William Shakespear
prompt = "\
To be or not to be, that is the question.\n\
Whether 't is nobler in the mind to suffer\n\
The slings and arrows of outrageous fortune,\n\
Or to take arms against a sea of troubles,\n\
And by opposing end them? To die, to sleep,\n\
No more; and by a sleep to say we end\n\
The heart-ache and the thousand natural shocks\n\
That flesh is heir to, 'tis a consummation.\n";
break;
case '18':
// J.R.R. Tolkien
prompt = "\
In a hole in the ground there lived a hobbit.\n\
Not a nasty, dirty, wet hole, filled with the ends of worms\n\
and an oozy smell, nor yet a dry, bare, sandy hole with nothing in it\n\
to sit down on or to eat: it was a hobbit-hole, and that means comfort.\n\
It had a perfectly round door like a porthole, painted green,\n\
with a shiny yellow brass knob in the exact middle.\n\
The door opened on to a tube-shaped hall like a tunnel:\n";
break;
case '19':
// George R.R. Martin
prompt = "\
A reader lives a thousand lives before he dies, said Jojen.\n\
The man who never reads lives only one.\n\
Theon Greyjoy had never been a reader.\n\
Never forget what you are, for surely the world will not.\n\
Make it your strength. Then it can never be your weaknessi\n\
Armour yourself in it, and it will never be used to hurt you.\n\
It was a lesson that Theon Greyjoy had never learned.\n\
Theon Greyjoy had never been a reader.\n";
break;
case '20':
// Stephen King
prompt = "\
The trust of the innocent is the liar's most useful tool.\n\
The best way to keep a secret is from yourself.\n\
Monsters are real, and ghosts are real too.\n\
They live inside us, and sometimes, they win.\n\
People think that I must be a very strange person.\n\
They think that I sit around all day thinking up horrible things.\n\
We make up horrors to help us cope with the real ones.\n\
The only thing worse than a monster is a human monster.\n";
break;
default:
prompt = "\
Hello, how are you?\n\
I'm fine, thanks. How are you?\n\
Thanks, I'm fine too. What are you doing?\n\
I'm just sitting here.\n\
It's a lovely day, isn't it?\n\
Yes, it is.\n\
Did you know that I'm a robot?\n\
I wasn't aware of that.\n";
break;
}
Module.set_prompt(prompt);
}
</script>
<script type="text/javascript" src="talk.js"></script>
</body>
</html>

View File

@ -1,2 +0,0 @@
audio.mp3
to_speak.txt

View File

@ -1,8 +0,0 @@
if (WHISPER_SDL2)
# talk
set(TARGET talk)
add_executable(${TARGET} talk.cpp gpt-2.cpp)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
include(DefaultTargetOptions)
endif ()

View File

@ -1,45 +0,0 @@
# talk
Talk with an Artificial Intelligence in your terminal
[Demo Talk](https://user-images.githubusercontent.com/1991296/206805012-48e71cc2-588d-4745-8798-c1c70ea3b40d.mp4)
Web version: [examples/talk.wasm](/examples/talk.wasm)
## Building
The `talk` tool depends on SDL2 library to capture audio from the microphone. You can build it like this:
```bash
# Install SDL2
# On Debian based linux distributions:
sudo apt-get install libsdl2-dev
# On Fedora Linux:
sudo dnf install SDL2 SDL2-devel
# Install SDL2 on Mac OS
brew install sdl2
# Build the "talk" executable
make talk
# Run it
./talk -p Santa
```
## GPT-2
To run this, you will need a ggml GPT-2 model: [instructions](https://github.com/ggerganov/ggml/tree/master/examples/gpt-2#downloading-and-converting-the-original-models)
Alternatively, you can simply download the smallest ggml GPT-2 117M model (240 MB) like this:
```
wget --quiet --show-progress -O models/ggml-gpt-2-117M.bin https://huggingface.co/ggerganov/ggml/resolve/main/ggml-model-gpt-2-117M.bin
```
## TTS
For best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak](speak) script to your needs.
By default, it is configured to use MacOS's `say` or `espeak` or Windows SpeechSynthesizer, but you can use whatever you wish.

View File

@ -1,80 +0,0 @@
import sys
import argparse
import textwrap
parser = argparse.ArgumentParser(add_help=False,
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("-q", "--quick", action="store_true",
help="skip checking the required library")
modes = parser.add_argument_group("action")
modes.add_argument("inputfile", metavar="TEXTFILE",
nargs='?', type=argparse.FileType(), default=sys.stdin,
help="read the text file (default: stdin)")
modes.add_argument("-l", "--list", action="store_true",
help="show the list of voices and exit")
modes.add_argument("-h", "--help", action="help",
help="show this help and exit")
selopts = parser.add_argument_group("voice selection")
selmodes = selopts.add_mutually_exclusive_group()
selmodes.add_argument("-n", "--name",
default="Arnold",
help="get a voice object by name (default: Arnold)")
selmodes.add_argument("-v", "--voice", type=int, metavar="NUMBER",
help="get a voice object by number (see --list)")
selopts.add_argument("-f", "--filter", action="append", metavar="KEY=VAL",
default=["use case=narration"],
help=textwrap.dedent('''\
filter voices by labels (default: "use case=narration")
this option can be used multiple times
filtering will be disabled if the first -f has no "=" (e.g. -f "any")
'''))
outmodes = parser.add_argument_group("output")
outgroup = outmodes.add_mutually_exclusive_group()
outgroup.add_argument("-s", "--save", metavar="FILE",
default="audio.mp3",
help="save the TTS to a file (default: audio.mp3)")
outgroup.add_argument("-p", "--play", action="store_true",
help="play the TTS with ffplay")
args = parser.parse_args()
if not args.quick:
import importlib.util
if importlib.util.find_spec("elevenlabs") is None:
print("elevenlabs library is not installed, you can install it to your enviroment using 'pip install elevenlabs'")
sys.exit()
from elevenlabs import voices, generate, play, save
if args.filter and "=" in args.filter[0]:
voicelist = voices()
for f in args.filter:
label, value = f.split("=")
voicelist = filter(lambda x: x.labels.get(label) == value, voicelist)
voicelist = list(voicelist)
else:
voicelist = list(voices())
if args.list:
for i, v in enumerate(voicelist):
print(str(i) + ": " + v.name + " " + str(v.labels))
sys.exit()
if args.voice:
voice = voicelist[args.voice % len(voicelist)]
else:
voice = args.name
# if -n should consult -f, use the following
#voice = next(x for x in voicelist if x.name == args.name)
audio = generate(
text=str(args.inputfile.read()),
voice=voice
)
if args.play:
play(audio)
else:
save(audio, args.save)

View File

@ -1,809 +0,0 @@
#include "ggml.h"
#include "common-ggml.h"
#include "gpt-2.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#include <regex>
#include <random>
/////////////////////// GPT-2 BEGIN /////////////////////////
// default hparams (GPT-2 117M)
struct gpt2_hparams {
int32_t n_vocab = 50257;
int32_t n_ctx = 1024;
int32_t n_embd = 768;
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t ftype = 1;
};
struct gpt2_layer {
// normalization
struct ggml_tensor * ln_1_g;
struct ggml_tensor * ln_1_b;
struct ggml_tensor * ln_2_g;
struct ggml_tensor * ln_2_b;
// attention
struct ggml_tensor * c_attn_attn_w;
struct ggml_tensor * c_attn_attn_b;
struct ggml_tensor * c_attn_proj_w;
struct ggml_tensor * c_attn_proj_b;
// mlp
struct ggml_tensor * c_mlp_fc_w;
struct ggml_tensor * c_mlp_fc_b;
struct ggml_tensor * c_mlp_proj_w;
struct ggml_tensor * c_mlp_proj_b;
};
struct gpt2_model {
gpt2_hparams hparams;
// normalization
struct ggml_tensor * ln_f_g;
struct ggml_tensor * ln_f_b;
struct ggml_tensor * wte; // position embedding
struct ggml_tensor * wpe; // token embedding
struct ggml_tensor * lm_head; // language model head
std::vector<gpt2_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
};
// load the model's weights from a file
static bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
// load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
if (n_vocab != model.hparams.n_vocab) {
fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
__func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
return false;
}
char word[129];
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word[len] = '\0';
fin.read((char *) word, len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
if (wtype == GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
__func__, fname.c_str(), model.hparams.ftype);
return false;
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // wte
ctx_size += n_ctx*ggml_row_size(GGML_TYPE_F32, n_embd); // wpe
ctx_size += n_vocab*ggml_row_size(wtype, n_embd); // lm_head
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b
ctx_size += n_layer*(ggml_row_size(wtype, 3*n_embd*n_embd)); // c_attn_attn_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd)); // c_attn_attn_b
ctx_size += n_layer*(ggml_row_size(wtype, n_embd*n_embd)); // c_attn_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_attn_proj_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_fc_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd)); // c_mlp_fc_b
ctx_size += n_layer*(ggml_row_size(wtype, 4*n_embd*n_embd)); // c_mlp_proj_w
ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // c_mlp_proj_b
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v
ctx_size += (6 + 12*n_layer)*256; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model.wte = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
model.wpe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
model.lm_head = ggml_new_tensor_2d(ctx, wtype, n_embd, n_vocab);
// map by name
model.tensors["model/ln_f/g"] = model.ln_f_g;
model.tensors["model/ln_f/b"] = model.ln_f_b;
model.tensors["model/wte"] = model.wte;
model.tensors["model/wpe"] = model.wpe;
model.tensors["model/lm_head"] = model.lm_head;
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 3*n_embd);
layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);
layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype, n_embd, 4*n_embd);
layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);
layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype, 4*n_embd, n_embd);
layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// map by name
model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;
model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
}
}
// key + value memory
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
}
// load weights
{
size_t total_size = 0;
bool has_lm_head = false;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ttype), sizeof(ttype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[2] = { 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
return false;
}
// for debugging
if (0) {
printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
}
if (name == "model/lm_head") {
has_lm_head = true;
}
total_size += ggml_nbytes(tensor);
}
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
fin.close();
return true;
}
// evaluate the transformer
//
// - model: the model
// - n_threads: number of threads to use
// - n_past: the context size so far
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
// TODO: sync latest version from ggml repo
static bool gpt2_eval(
const gpt2_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
std::vector<float> & embd_w,
size_t & mem_per_token) {
const int N = embd_inp.size();
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
static size_t buf_size = 512u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
// wte + wpe
struct ggml_tensor * inpL =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.wte, embd),
ggml_get_rows(ctx0, model.wpe, position));
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * cur;
// norm
{
// [ 768, N]
cur = ggml_norm(ctx0, inpL, 1e-5f);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
}
// attn
// [2304, 768] - model.layers[il].c_attn_attn_w
// [2304, 1] - model.layers[il].c_attn_attn_b
// [ 768, N] - cur (in)
// [2304, N] - cur (out)
//
// cur = attn_w*cur + attn_b
// [2304, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_attn_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
cur);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
// [64, N, 12]
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
// [64, n_past + N, 12]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
// GG: flash attention
//struct ggml_tensor * V =
// ggml_cpy(ctx0,
// ggml_permute(ctx0,
// ggml_reshape_3d(ctx0,
// ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
// n_embd/n_head, n_head, n_past + N),
// 1, 2, 0, 3),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
//struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
// K * Q
// [n_past + N, N, 12]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
1.0f/sqrt(float(n_embd)/n_head));
// KQ_masked = mask_past(KQ_scaled)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
// [n_past + N, N, 12]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
// [n_past + N, 64, 12]
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
// [64, N, 12]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// [64, 12, N]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_embd, N)
// [768, N]
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
}
// projection
// [ 768, 768] - model.layers[il].c_attn_proj_w
// [ 768, 1] - model.layers[il].c_attn_proj_b
// [ 768, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
{
cur = ggml_mul_mat(ctx0,
model.layers[il].c_attn_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
cur);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, 1e-5f);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
}
// fully connected
// [3072, 768] - model.layers[il].c_mlp_fc_w
// [3072, 1] - model.layers[il].c_mlp_fc_b
// [ 768, N] - cur (in)
// [3072, N] - cur (out)
//
// cur = fc_w*cur + fc_b
// [3072, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_fc_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
cur);
// GELU activation
// [3072, N]
cur = ggml_gelu(ctx0, cur);
// projection
// [ 768, 3072] - model.layers[il].c_mlp_proj_w
// [ 768, 1] - model.layers[il].c_mlp_proj_b
// [3072, N] - cur (in)
// [ 768, N] - cur (out)
//
// cur = proj_w*cur + proj_b
// [768, N]
cur = ggml_mul_mat(ctx0,
model.layers[il].c_mlp_proj_w,
cur);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
cur);
}
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
// [ 768, N]
inpL = ggml_norm(ctx0, inpL, 1e-5f);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
inpL = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
}
// inpL = WTE * inpL
// [ 768, 50257] - model.lm_head
// [ 768, N] - inpL
inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);
// logits -> probs
//inpL = ggml_soft_max(ctx0, inpL);
// run the computation
ggml_build_forward_expand (&gf, inpL);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
// ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
//}
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
ggml_free(ctx0);
return true;
}
/////////////////////////////// GPT-2 END ////////////////////////////////
constexpr int N_THREAD = 8;
struct gpt2_context {
std::string prompt_base = R"(Hello, how are you?
I'm fine, thanks. How are you?
Thanks, I'm fine too. What are you doing?
I'm just sitting here.
It's a lovely day, isn't it?
Yes, it is. I love the weather this time of year.
I wish it would rain a little bit.
Me too.
)";
std::mt19937 rng;
gpt_vocab vocab;
gpt2_model model;
int32_t n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency());
// sampling parameters
int32_t top_k = 5;
float top_p = 0.9f;
float temp = 1.0f;
};
struct gpt2_context * gpt2_init(const char * path_model) {
gpt2_context * ctx = new gpt2_context;
ctx->rng = std::mt19937(time(nullptr));
// load the model
{
const int64_t t_start_us = ggml_time_us();
if (!gpt2_model_load(path_model, ctx->model, ctx->vocab)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
delete ctx;
return nullptr;
}
const int64_t t_load_us = ggml_time_us() - t_start_us;
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
return ctx;
}
void gpt2_free(struct gpt2_context * ctx) {
delete ctx;
}
const char * gpt2_get_prompt(struct gpt2_context * ctx) {
return ctx->prompt_base.c_str();
}
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt) {
ctx->prompt_base = prompt;
}
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text) {
return ::gpt_tokenize(ctx->vocab, text);
}
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens) {
int n_past = 0;
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt2_tokenize(ctx, text);
int n_predict = std::min(max_tokens, ctx->model.hparams.n_ctx - (int) embd_inp.size());
std::vector<gpt_vocab::id> embd = embd_inp;
size_t mem_per_token = 3000000;
std::string result;
for (int i = embd.size(); i < (int) embd_inp.size() + n_predict; i++) {
// predict
if (!embd.empty()) {
if (!gpt2_eval(ctx->model, ctx->n_threads, n_past, embd, embd_w, mem_per_token)) {
printf("gpt-2: failed to generate text\n");
return "";
}
}
n_past += embd.size();
embd.clear();
{
// sample next token
const int top_k = ctx->top_k;
const float top_p = ctx->top_p;
const float temp = ctx->temp;
const int n_vocab = ctx->model.hparams.n_vocab;
const gpt_vocab::id id = gpt_sample_top_k_top_p(ctx->vocab, embd_w.data() + (embd_w.size() - n_vocab), top_k, top_p, temp, ctx->rng);
// add it to the context
embd.push_back(id);
}
result += ctx->vocab.id_to_token[embd[0]];
// end of text token
if (embd.back() == 50256) {
break;
}
}
return result;
}

View File

@ -1,21 +0,0 @@
#pragma once
// TODO: Change to C-style API and move to ./examples for easy reuse.
#include "common.h"
#include <vector>
#include <map>
#include <string>
struct gpt2_context;
struct gpt2_context * gpt2_init(const char * path_model);
void gpt2_free(struct gpt2_context * ctx);
const char * gpt2_get_prompt(struct gpt2_context * ctx);
void gpt2_set_prompt(struct gpt2_context * ctx, const char * prompt);
std::vector<gpt_vocab::id> gpt2_tokenize(const gpt2_context * ctx, const char * text);
std::string gpt2_gen_text(gpt2_context * ctx, const char * text, int max_tokens);

View File

@ -1,40 +0,0 @@
#!/bin/bash
# Usage:
# speak <voice_id> <textfile>
function installed() { command -v $1 >/dev/null 2>&1; }
if installed espeak; then
espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 -f $2
elif installed piper && installed aplay; then
cat $2 | piper --model ~/en_US-lessac-medium.onnx --output-raw | aplay -q -r 22050 -f S16_LE -t raw -
# for Mac
elif installed say; then
say -f $2
# Eleven Labs
elif installed python3 && \
python3 -c 'import importlib.util; exit(not importlib.util.find_spec("elevenlabs"))' && \
installed ffplay; then
# It's possible to use the API for free with limited number of characters.
# To increase this limit register to https://beta.elevenlabs.io to get an api key
# and paste it after 'ELEVEN_API_KEY='
# Keep the line commented to use the free version without api key
#export ELEVEN_API_KEY=your_api_key
wd=$(dirname $0)
script=$wd/eleven-labs.py
python3 $script -q -p -v $1 $2 >/dev/null 2>&1
# Uncomment to keep the audio file
#python3 $script -q -s ./audio.mp3 -v $1 $2 >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
else
echo 'Install espeak ("brew install espeak" or "apt-get install espeak"),'
echo 'piper ("pip install piper-tts" or https://github.com/rhasspy/piper) with aplay,'
echo 'or elevenlabs ("pip install elevenlabs") with ffplay.'
echo '(export ELEVEN_API_KEY if you have an api key from https://beta.elevenlabs.io)'
fi

View File

@ -1 +0,0 @@
@powershell -ExecutionPolicy Bypass -F examples\talk\speak.ps1 %1 %2

View File

@ -1,14 +0,0 @@
# Set-ExecutionPolicy -ExecutionPolicy Bypass -Scope CurrentUser
param(
[Parameter(Mandatory=$true)][int]$voicenum,
[Parameter(Mandatory=$true)][string]$textfile
)
Add-Type -AssemblyName System.Speech;
$speak = New-Object System.Speech.Synthesis.SpeechSynthesizer;
$voiceoptions = $speak.GetInstalledVoices("en-US");
$voice = $voiceoptions[$voicenum % $voiceoptions.count];
$speak.SelectVoice($voice.VoiceInfo.Name);
$speak.Rate="0";
$text = Get-Content -Path $textfile;
$speak.Speak($text);

View File

@ -1,376 +0,0 @@
// Talk with AI
//
#include "common-sdl.h"
#include "common.h"
#include "whisper.h"
#include "gpt-2.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
#include <regex>
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool use_gpu = true;
bool flash_attn = false;
std::string person = "Santa";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_gpt = "models/ggml-gpt-2-117M.bin";
std::string speak = "./examples/talk/speak";
std::string speak_file= "./examples/talk/to_speak.txt";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-mg" || arg == "--model-gpt") { params.model_gpt = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-sf" || arg == "--speak_file") { params.speak_file = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -mg FILE, --model-gpt [%-7s] gpt model file\n", params.model_gpt.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " -sf FILE, --speak_file [%-7s] file to pass to TTS\n", params.speak_file.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
static std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt =
R"(This is a dialogue between {0} (A) and a person (B). The dialogue so far is:
B: Hello {0}, how are you?
A: I'm fine, thank you.
{1}
Here is how {0} (A) continues the dialogue:
A:)";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
// gpt init
struct gpt2_context * ctx_gpt = gpt2_init(params.model_gpt.c_str());
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
gpt2_set_prompt(ctx_gpt, "");
const int voice_id = rand()%6;
fprintf(stderr, "gpt-2: prompt:\n");
fprintf(stderr, "========================\n\n");
fprintf(stderr, "%s\n", ::replace(k_prompt, "{0}", params.person).c_str());
fprintf(stderr, "========================\n\n");
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between brackets using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<gpt_vocab::id> tokens = gpt2_tokenize(ctx_gpt, text_heard.c_str());
if (text_heard.empty() || tokens.empty() || force_speak) {
fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
std::string prompt_base = gpt2_get_prompt(ctx_gpt);
std::string text_to_speak;
{
prompt_base += "B: " + text_heard + "\n";
std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
//text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
// remove first 2 lines of base prompt
if (n_iter > 4) {
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = prompt_base.find_first_of('\n');
if (pos != std::string::npos) {
prompt_base = prompt_base.substr(pos + 1);
}
}
}
prompt_base += "A:" + text_to_speak + "\n";
{
prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
printf("===============\n");
printf("prompt:\n");
printf("%s\n", prompt.c_str());
printf("===============\n");
}
}
//printf("========================\n");
//printf("gpt-2: prompt_base:\n%s\n", prompt_base.c_str());
//printf("========================\n");
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
speak_with_file(params.speak, text_to_speak, params.speak_file, voice_id);
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
return 0;
}

View File

@ -29,7 +29,7 @@ help()
check_requirements()
{
if ! command -v ./main &>/dev/null; then
if ! command -v ./build/bin/whisper-cli &>/dev/null; then
echo "whisper.cpp main executable is required (make)"
exit 1
fi
@ -100,7 +100,7 @@ do
err=$(cat /tmp/whisper-live.err | wc -l)
done
./main -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
./build/bin/whisper-cli -t $threads -m ./models/ggml-$model.bin -f /tmp/whisper-live.wav --no-timestamps -otxt 2> /tmp/whispererr | tail -n 1
while [ $SECONDS -lt $((($i+1)*$step)) ]; do
sleep 1

View File

@ -2,11 +2,11 @@ cmake_minimum_required(VERSION 3.10)
project(whisper.cpp)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../..)
# Path to external GGML, otherwise uses the copy in whisper.cpp.
option(GGML_HOME "whisper: Path to external GGML source" OFF)
option(GGML_HOME "whisper: Path to external GGML source" OFF)
set(
SOURCE_FILES
@ -14,16 +14,24 @@ set(
${CMAKE_SOURCE_DIR}/jni.c
)
# TODO: this needs to be updated to work with the new ggml CMakeLists
if (NOT GGML_HOME)
set(
SOURCE_FILES
${SOURCE_FILES}
${WHISPER_LIB_DIR}/ggml/src/ggml.c
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu.c
${WHISPER_LIB_DIR}/ggml/src/ggml-aarch64.c
${WHISPER_LIB_DIR}/ggml/src/ggml-alloc.c
${WHISPER_LIB_DIR}/ggml/src/ggml-backend.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-backend-reg.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-quants.c
${WHISPER_LIB_DIR}/ggml/src/ggml-threading.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.c
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-quants.c
${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-traits.cpp
)
endif()
@ -81,3 +89,5 @@ include_directories(${WHISPER_LIB_DIR}/src)
include_directories(${WHISPER_LIB_DIR}/include)
include_directories(${WHISPER_LIB_DIR}/ggml/include)
include_directories(${WHISPER_LIB_DIR}/ggml/src)
include_directories(${WHISPER_LIB_DIR}/ggml/src/ggml-cpu)

View File

@ -25,6 +25,11 @@
18ABE15A2AF556340044A204 /* ggml-backend.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1572AF556340044A204 /* ggml-backend.cpp */; };
18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1592AF556340044A204 /* ggml-quants.c */; };
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */ = {isa = PBXBuildFile; fileRef = 18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */; };
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */; };
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */; };
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.c in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */; };
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */; };
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */; };
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
7FE3424C2A0C3FA20015A058 /* whisper-encoder.mm in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342472A0C3FA20015A058 /* whisper-encoder.mm */; };
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE3424A2A0C3FA20015A058 /* whisper-decoder-impl.m */; };
@ -50,8 +55,8 @@
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-aarch64.c"; path = "../../../ggml/src/ggml-aarch64.c"; sourceTree = "<group>"; };
184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml/src/ggml-alloc.c"; sourceTree = "<group>"; };
184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml/include/ggml-alloc.h"; sourceTree = "<group>"; };
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml/src/ggml-metal.m"; sourceTree = "<group>"; };
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml/src/ggml-metal.metal"; sourceTree = "<group>"; };
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml/src/ggml-metal/ggml-metal.m"; sourceTree = "<group>"; };
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml/src/ggml-metal/ggml-metal.metal"; sourceTree = "<group>"; };
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
@ -77,8 +82,17 @@
18ABE1572AF556340044A204 /* ggml-backend.cpp */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.cpp; fileEncoding = 4; name = "ggml-backend.cpp"; path = "../../../ggml/src/ggml-backend.cpp"; sourceTree = "<group>"; };
18ABE1582AF556340044A204 /* ggml-impl.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-impl.h"; path = "../../../ggml/src/ggml-impl.h"; sourceTree = "<group>"; };
18ABE1592AF556340044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-quants.c"; path = "../../../ggml/src/ggml-quants.c"; sourceTree = "<group>"; };
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu.c"; path = "../../../ggml/src/ggml-cpu.c"; sourceTree = "<group>"; };
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu.c"; sourceTree = "<group>"; };
18E864AA2CE73C580094B8B3 /* ggml-cpu.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu.h"; path = "../../../ggml/include/ggml-cpu.h"; sourceTree = "<group>"; };
18F8C0BA2CEDF4DC00CAD607 /* ggml-threading.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-threading.h"; path = "../../../ggml/src/ggml-threading.h"; sourceTree = "<group>"; };
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-threading.cpp"; path = "../../../ggml/src/ggml-threading.cpp"; sourceTree = "<group>"; };
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-cpu.cpp"; path = "../../../ggml/src/ggml-cpu/ggml-cpu.cpp"; sourceTree = "<group>"; };
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-aarch64.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.h"; sourceTree = "<group>"; };
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-aarch64.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-aarch64.c"; sourceTree = "<group>"; };
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-impl.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-impl.h"; sourceTree = "<group>"; };
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = "ggml-cpu-quants.h"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.h"; sourceTree = "<group>"; };
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.c; name = "ggml-cpu-quants.c"; path = "../../../ggml/src/ggml-cpu/ggml-cpu-quants.c"; sourceTree = "<group>"; };
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = "ggml-backend-reg.cpp"; path = "../../../ggml/src/ggml-backend-reg.cpp"; sourceTree = "<group>"; };
7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "whisper-encoder-impl.m"; sourceTree = "<group>"; };
7FE342462A0C3FA20015A058 /* whisper-encoder.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "whisper-encoder.h"; sourceTree = "<group>"; };
7FE342472A0C3FA20015A058 /* whisper-encoder.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = "whisper-encoder.mm"; sourceTree = "<group>"; };
@ -118,6 +132,15 @@
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
isa = PBXGroup;
children = (
18F8C0C62CEDF7AB00CAD607 /* ggml-backend-reg.cpp */,
18F8C0BF2CEDF52700CAD607 /* ggml-cpu-aarch64.h */,
18F8C0C02CEDF52700CAD607 /* ggml-cpu-aarch64.c */,
18F8C0C12CEDF52700CAD607 /* ggml-cpu-impl.h */,
18F8C0C22CEDF52700CAD607 /* ggml-cpu-quants.h */,
18F8C0C32CEDF52700CAD607 /* ggml-cpu-quants.c */,
18F8C0BD2CEDF50700CAD607 /* ggml-cpu.cpp */,
18F8C0BA2CEDF4DC00CAD607 /* ggml-threading.h */,
18F8C0BB2CEDF4DC00CAD607 /* ggml-threading.cpp */,
18E864AA2CE73C580094B8B3 /* ggml-cpu.h */,
18E864A82CE73C1E0094B8B3 /* ggml-cpu.c */,
18133C7F2C64E342005CEAAC /* ggml-aarch64.c */,
@ -252,11 +275,16 @@
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
18F8C0C72CEDF7AB00CAD607 /* ggml-backend-reg.cpp in Sources */,
18F8C0BE2CEDF50700CAD607 /* ggml-cpu.cpp in Sources */,
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
18F8C0C42CEDF52700CAD607 /* ggml-cpu-aarch64.c in Sources */,
18F8C0C52CEDF52700CAD607 /* ggml-cpu-quants.c in Sources */,
18E864A92CE73C1E0094B8B3 /* ggml-cpu.c in Sources */,
18ABE15A2AF556340044A204 /* ggml-backend.cpp in Sources */,
18627C8C29052BE000BD2A04 /* main.m in Sources */,
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
18F8C0BC2CEDF4DC00CAD607 /* ggml-threading.cpp in Sources */,
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
);
@ -335,6 +363,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
@ -388,6 +417,7 @@
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
@ -410,6 +440,7 @@
DEVELOPMENT_TEAM = P8JZH34X63;
GCC_WARN_64_TO_32_BIT_CONVERSION = NO;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
INFOPLIST_FILE = whisper.objc/Info.plist;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchStoryboardName = LaunchScreen;
@ -439,6 +470,7 @@
DEVELOPMENT_TEAM = P8JZH34X63;
GCC_WARN_64_TO_32_BIT_CONVERSION = NO;
GENERATE_INFOPLIST_FILE = YES;
HEADER_SEARCH_PATHS = ../../../ggml/src/;
INFOPLIST_FILE = whisper.objc/Info.plist;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchStoryboardName = LaunchScreen;

View File

@ -66,9 +66,7 @@ actor WhisperContext {
private func systemInfo() -> String {
var info = ""
if (ggml_cpu_has_neon() != 0) { info += "NEON " }
if (ggml_cpu_has_metal() != 0) { info += "METAL " }
if (ggml_cpu_has_blas() != 0) { info += "BLAS " }
//if (ggml_cpu_has_neon() != 0) { info += "NEON " }
return String(info.dropLast())
}
@ -77,45 +75,45 @@ actor WhisperContext {
if (whisper_set_mel(context, nil, 0, nMels) != 0) {
return "error: failed to set mel"
}
// heat encoder
if (whisper_encode(context, 0, nThreads) != 0) {
return "error: failed to encode"
}
var tokens = [whisper_token](repeating: 0, count: 512)
// prompt heat
if (whisper_decode(context, &tokens, 256, 0, nThreads) != 0) {
return "error: failed to decode"
}
// text-generation heat
if (whisper_decode(context, &tokens, 1, 256, nThreads) != 0) {
return "error: failed to decode"
}
whisper_reset_timings(context)
// actual run
if (whisper_encode(context, 0, nThreads) != 0) {
return "error: failed to encode"
}
// text-generation
for i in 0..<256 {
if (whisper_decode(context, &tokens, 1, Int32(i), nThreads) != 0) {
return "error: failed to decode"
}
}
// batched decoding
for _ in 0..<64 {
if (whisper_decode(context, &tokens, 5, 0, nThreads) != 0) {
return "error: failed to decode"
}
}
// prompt processing
for _ in 0..<16 {
if (whisper_decode(context, &tokens, 256, 0, nThreads) != 0) {

View File

@ -55,7 +55,7 @@ MODEL_PATH="${MODEL_PATH:-${SCRIPT_DIR}/../models/ggml-base.en.bin}"
# Where to find the whisper.cpp executable. default to the examples directory
# which holds this script in source control
################################################################################
WHISPER_EXECUTABLE="${WHISPER_EXECUTABLE:-${SCRIPT_DIR}/../main}";
WHISPER_EXECUTABLE="${WHISPER_EXECUTABLE:-${SCRIPT_DIR}/../build/bin/whisper-cli}";
# Set to desired language to be translated into english
WHISPER_LANG="${WHISPER_LANG:-en}";

View File

@ -32,7 +32,15 @@ else()
endif()
endif()
# remove the lib prefix on win32 mingw
if (WIN32)
set(CMAKE_STATIC_LIBRARY_PREFIX "")
set(CMAKE_SHARED_LIBRARY_PREFIX "")
set(CMAKE_SHARED_MODULE_PREFIX "")
endif()
option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
#
# option list
@ -91,31 +99,38 @@ else()
set(INS_ENB ON)
endif()
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
option(GGML_AVX512 "ggml: enable AVX512" OFF)
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
option(GGML_AVX512 "ggml: enable AVX512F" OFF)
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
if (NOT MSVC)
option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
# in MSVC F16C and FMA is implied with AVX2/AVX512
option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
option(GGML_F16C "ggml: enable F16C" ${INS_ENB})
# MSVC does not seem to support AMX
option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_SVE "ggml: enable SVE" OFF)
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_SVE "ggml: enable SVE" OFF)
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
if (WIN32)
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version")
endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
@ -126,14 +141,9 @@ option(GGML_LLAMAFILE "ggml: use LLAMAFILE"
option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_MUSA "ggml: use MUSA" OFF)
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
"ggml: iters./thread per block for Q2_K/Q6_K")
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
@ -141,7 +151,7 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
option(GGML_HIP "ggml: use HIP" OFF)
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
@ -162,11 +172,17 @@ set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)")
option(GGML_OPENMP "ggml: use OpenMP" ON)
option(GGML_RPC "ggml: use RPC" OFF)
option(GGML_AMX "ggml: use AMX" OFF)
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
"ggml: sycl device architecture")
option(GGML_OPENCL "ggml: use OpenCL" OFF)
option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF)
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
# extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
@ -179,11 +195,7 @@ option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
if (GGML_SYCL)
set(CMAKE_CXX_STANDARD 17)
else()
set(CMAKE_CXX_STANDARD 11)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(THREADS_PREFER_PTHREAD_FLAG ON)
@ -226,6 +238,7 @@ set(GGML_PUBLIC_HEADERS
include/ggml-cann.h
include/ggml-cuda.h
include/ggml-kompute.h
include/ggml-opt.h
include/ggml-metal.h
include/ggml-rpc.h
include/ggml-sycl.h
@ -235,15 +248,14 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml PUBLIC_HEADER)
if (BUILD_SHARED_LIBS)
install(TARGETS ggml LIBRARY)
endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
# FIXME: this should be done in the backend cmake files
if (GGML_METAL)
# FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
install(
FILES src/ggml-metal.metal
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE

View File

@ -1,220 +0,0 @@
#!/usr/bin/env python
import logging
import argparse
import asyncio
import os
from tempfile import gettempdir
logger = logging.getLogger("ggml-vk-generate-shaders")
GLSLC = "glslc"
type_names = [
"f32",
"f16",
"q4_0",
"q4_1",
"q5_0",
"q5_1",
"q8_0",
"q2_k",
"q3_k",
"q4_k",
"q5_k",
"q6_k",
]
ASYNCIO_CONCURRENCY = 64
input_dir = "vulkan-shaders"
output_dir = gettempdir()
lock = asyncio.Lock()
shader_fnames = []
async def string_to_spv(name, in_fname, defines, fp16=True):
name = f"{name}{'_fp32' if not fp16 else ''}"
out_fname = os.path.join(output_dir, f"{name}.spv")
in_path = os.path.join(input_dir, in_fname)
cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname]
cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
stdout, stderr = await proc.communicate()
stdout = stdout.decode()
error = stderr.decode()
if proc.returncode:
cmd = " ".join(cmd)
logger.error(f"cannot compile {name}\n\n{cmd}\n\n{error}")
return
async with lock:
shader_fnames.append((name, out_fname))
def matmul_shaders(tasks, fp16, matmul_id):
if fp16:
load_vec = "8"
aligned_b_type_f32 = "mat2x4"
aligned_b_type_f16 = "f16mat2x4"
else:
load_vec = "4"
aligned_b_type_f32 = "vec4"
aligned_b_type_f16 = "f16vec4"
base_dict = {"FLOAT_TYPE": "float" if not fp16 else "float16_t"}
shader_name = "matmul"
if matmul_id:
base_dict["MUL_MAT_ID"] = "1"
shader_name = "matmul_id"
if fp16:
base_dict["FLOAT16"] = "1"
# Shaders with f16 B_TYPE
tasks.append(string_to_spv(f"{shader_name}_f32_f16", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f32_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f16", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
for tname in type_names:
data_a_key = f"DATA_A_{tname.upper()}"
load_vec_a = load_vec if tname in ("f32", "f16") else "2"
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32", "mul_mm.comp", base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32_aligned", "mul_mm.comp", base_dict | {data_a_key: "2", "LOAD_VEC_A": load_vec_a, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f32, "D_TYPE": "float"}, fp16))
async def main():
logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")
tasks = []
for fp16 in (False, True):
# MUL_MAT
matmul_shaders(tasks, fp16, False)
# MUL_MAT_ID
matmul_shaders(tasks, fp16, True)
for tname in type_names:
base_dict = {"FLOAT_TYPE": "float"}
# mul mat vec
data_a_key = f"DATA_A_{tname.upper()}"
shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f32_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f16_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float16_t", "D_TYPE": "float"}))
tasks.append(string_to_spv(f"mul_mat_vec_id_{tname}_f32", shader, base_dict | {"MUL_MAT_ID": "1", data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
# Dequant shaders
if tname != "f16":
tasks.append(string_to_spv(f"dequant_{tname}", f"dequant_{tname}.comp", base_dict | {data_a_key: "1", "D_TYPE": "float16_t"}))
# get_rows
if not tname.endswith("_k"):
shader = "get_rows.comp" if tname in ("f32", "f16") else "get_rows_quant.comp"
if tname == "f16":
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
else:
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv(f"get_rows_{tname}_f32", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float"}))
tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
# Norms
tasks.append(string_to_spv("norm_f32", "norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rms_norm_f32", "rms_norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("cpy_f32_f32", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("cpy_f32_f16", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("cpy_f16_f16", "copy.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
tasks.append(string_to_spv("add_f32", "add.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}))
tasks.append(string_to_spv("mul_f32", "mul.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("div_f32", "div.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("scale_f32", "scale.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("sqr_f32", "square.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("clamp_f32", "clamp.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
tasks.append(string_to_spv("gelu_f32", "gelu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("silu_f32", "silu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("relu_f32", "relu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32_f16", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float16_t", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_norm_f32", "rope_norm.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_norm_f16", "rope_norm.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("rope_neox_f32", "rope_neox.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", "rope_neox.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("argsort_f32", "argsort.comp", {"A_TYPE": "float"}))
tasks.append(string_to_spv("sum_rows_f32", "sum_rows.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
# Helper to decorate tasks with semaphore acquisition.
async def withSemaphore(sem, task):
async with sem:
return await task
# Run tasks concurrently guarded by a concurrency limit.
sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
with open("ggml-vulkan-shaders.hpp", "w") as f:
f.write("#include <cstdint>\n\n")
for name, path in sorted(shader_fnames):
with open(path, "rb") as spv:
counter = 0
newline_counter = 0
f.write(f"unsigned char {name}_data[] = {{\n")
for val in spv.read():
f.write(f"0x{val:02x},")
newline_counter += 1
counter += 1
if newline_counter >= 12:
newline_counter = 0
f.write("\n")
f.write("\n};\n")
f.write(f"const uint64_t {name}_len = {counter};\n\n")
os.remove(path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
parser.add_argument("--glslc", help="Path to glslc")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.glslc:
GLSLC = args.glslc
asyncio.run(main())

View File

@ -1,25 +0,0 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
// buffer_type API
GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
// backend API
GGML_API ggml_backend_t ggml_backend_amx_init(void);
GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
#ifdef __cplusplus
}
#endif

View File

@ -3,6 +3,20 @@
#include "ggml.h"
#include "ggml-alloc.h"
#ifdef GGML_BACKEND_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
# ifdef GGML_BACKEND_BUILD
# define GGML_BACKEND_API __declspec(dllexport) extern
# else
# define GGML_BACKEND_API __declspec(dllimport) extern
# endif
# else
# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
# endif
#else
# define GGML_BACKEND_API extern
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -72,7 +86,7 @@ extern "C" {
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// "offset" refers to the offset of the tensor data for setting/getting data
// "offset" refers to the offset in tensor->data for setting/getting data
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
@ -176,6 +190,14 @@ extern "C" {
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
// Get additional buffer types provided by the device (returns a NULL-terminated array)
typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
// Set the abort callback for the backend
typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
// Get a list of feature flags supported by the backend (returns a NULL-terminated array)
struct ggml_backend_feature {
const char * name;
const char * value;
};
typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
//
// Backend registry
@ -200,6 +222,14 @@ extern "C" {
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
GGML_API ggml_backend_t ggml_backend_init_best(void);
// Load a backend from a dynamic library and register it
GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
// Unload a backend if loaded dynamically and unregister it
GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
// Load all known backends from dynamic libraries
GGML_API void ggml_backend_load_all(void);
GGML_API void ggml_backend_load_all_from_path(const char * dir_path);
//
// Backend scheduler
//
@ -228,14 +258,20 @@ extern "C" {
ggml_backend_sched_reserve(sched, reserve_graph);
// compute
graph = build_graph(sched);
ggml_backend_sched_graph_compute(sched, graph);
graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
for (int i = 0; i < 10; ++i) {
ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
}
// if there are graph inputs:
ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, graph);
ggml_backend_tensor_set(input_tensor, ...);
ggml_backend_sched_graph_compute(sched, graph);
graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
ggml_backend_sched_graph_compute(sched, graph); // execute the graph
// as an alternative to the above it is also possible to assign the inputs to a dedicated context and
// allocate them statically via ggml_backend_alloc_ctx_tensors
}
*/
@ -250,7 +286,7 @@ extern "C" {
//
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
@ -275,7 +311,9 @@ extern "C" {
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
// Reset all assignments and allocators - must be called before changing the node backends
// Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
// This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
// The correct way to use this API is to discard the deallocated tensors and create new ones.
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute

View File

@ -9,15 +9,15 @@ extern "C" {
#endif
// backend API
GGML_API ggml_backend_t ggml_backend_blas_init(void);
GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);
GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
// number of threads used for conversion to float
// for openblas and blis, this will also set the number of threads used for blas operations
GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
GGML_API ggml_backend_reg_t ggml_backend_blas_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
#ifdef __cplusplus

View File

@ -34,7 +34,7 @@ extern "C" {
*/
#define GGML_CANN_MAX_DEVICES 16
GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
/**
* @brief Initializes the CANN backend for a specified device.
@ -46,7 +46,7 @@ GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
* @param device The index of the device to initialize.
* @return A pointer to the initialized backend instance, or nullptr on failure.
*/
GGML_API ggml_backend_t ggml_backend_cann_init(int32_t device);
GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
/**
* @brief Checks if a given backend is a CANN backend.
@ -57,7 +57,7 @@ GGML_API ggml_backend_t ggml_backend_cann_init(int32_t device);
* @param backend The backend instance to check.
* @return True if the backend is a CANN backend, false otherwise.
*/
GGML_API bool ggml_backend_is_cann(ggml_backend_t backend);
GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
/**
* @brief Retrieves the CANN buffer type for a specified device.
@ -69,7 +69,7 @@ GGML_API bool ggml_backend_is_cann(ggml_backend_t backend);
* @return A pointer to the buffer type interface for the specified device, or
* nullptr if the device index is out of range.
*/
GGML_API ggml_backend_buffer_type_t
GGML_BACKEND_API ggml_backend_buffer_type_t
ggml_backend_cann_buffer_type(int32_t device);
/**
@ -80,14 +80,14 @@ ggml_backend_cann_buffer_type(int32_t device);
*
* @return The number of CANN devices available.
*/
GGML_API int32_t ggml_backend_cann_get_device_count(void);
GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
/**
* @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
*
* @return A pointer to the host buffer type interface.
*/
GGML_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
/**
* @brief Retrieves the description of a specific CANN device.
@ -99,7 +99,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
* @param description Pointer to a buffer where the description will be written.
* @param description_size Size of the description buffer.
*/
GGML_API void ggml_backend_cann_get_device_description(
GGML_BACKEND_API void ggml_backend_cann_get_device_description(
int32_t device, char* description, size_t description_size);
/**
@ -114,7 +114,7 @@ GGML_API void ggml_backend_cann_get_device_description(
* @param total Pointer to a variable where the total memory size will be
* stored.
*/
GGML_API void ggml_backend_cann_get_device_memory(int32_t device,
GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
size_t* free,
size_t* total);

View File

@ -7,29 +7,6 @@
extern "C" {
#endif
// Scheduling priorities
enum ggml_sched_priority {
GGML_SCHED_PRIO_NORMAL,
GGML_SCHED_PRIO_MEDIUM,
GGML_SCHED_PRIO_HIGH,
GGML_SCHED_PRIO_REALTIME
};
// Threadpool params
// Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults
struct ggml_threadpool_params {
bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
int n_threads; // number of threads
enum ggml_sched_priority prio; // thread priority
uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
bool strict_cpu; // strict cpu placement
bool paused; // start in paused state
};
struct ggml_threadpool; // forward declaration, see ggml.c
typedef struct ggml_threadpool * ggml_threadpool_t;
// the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan {
@ -54,96 +31,104 @@ extern "C" {
GGML_NUMA_STRATEGY_COUNT
};
GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
GGML_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
GGML_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
GGML_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool);
GGML_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
GGML_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan(
GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
const struct ggml_cgraph * cgraph,
int n_threads, /* = GGML_DEFAULT_N_THREADS */
struct ggml_threadpool * threadpool /* = NULL */ );
GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
// TODO: move to backend interface
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_sve (void);
GGML_API int ggml_cpu_has_matmul_int8(void);
// get the sve vector length in bytes
GGML_API int ggml_cpu_get_sve_cnt(void);
//
// system info
//
// x86
GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
GGML_BACKEND_API int ggml_cpu_has_avx (void);
GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
GGML_BACKEND_API int ggml_cpu_has_f16c (void);
GGML_BACKEND_API int ggml_cpu_has_fma (void);
GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
// ARM
GGML_BACKEND_API int ggml_cpu_has_neon (void);
GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
GGML_BACKEND_API int ggml_cpu_has_dotprod (void);
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
GGML_BACKEND_API int ggml_cpu_has_sve (void);
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
// other
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
// Internal types and functions exposed for tests and benchmarks
typedef void (*ggml_from_float_to_mat_t)
(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
const void * GGML_RESTRICT y, size_t by, int nrc);
typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
struct ggml_type_traits_cpu {
ggml_from_float_to_mat_t from_float_to_mat;
ggml_from_float_t from_float;
ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type;
int64_t nrows; // number of rows to process simultaneously
int64_t ncols; // number of columns to process simultaneously
ggml_gemv_t gemv;
ggml_gemm_t gemm;
};
GGML_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
GGML_API void ggml_cpu_init(void);
GGML_BACKEND_API void ggml_cpu_init(void);
//
// CPU backend
//
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
#ifdef GGML_USE_CPU_HBM
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
#ifdef __cplusplus
}

View File

@ -7,7 +7,7 @@
extern "C" {
#endif
#ifdef GGML_USE_HIPBLAS
#ifdef GGML_USE_HIP
#define GGML_CUDA_NAME "ROCm"
#define GGML_CUBLAS_NAME "hipBLAS"
#elif defined(GGML_USE_MUSA)
@ -20,27 +20,27 @@ extern "C" {
#define GGML_CUDA_MAX_DEVICES 16
// backend API
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
// device buffer
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
GGML_API int ggml_backend_cuda_get_device_count(void);
GGML_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
GGML_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
GGML_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
GGML_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
#ifdef __cplusplus
}

View File

@ -37,13 +37,13 @@ struct ggml_vk_device ggml_vk_current_device(void);
// forward declaration
typedef struct ggml_backend * ggml_backend_t;
GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
GGML_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
#ifdef __cplusplus
}

View File

@ -39,27 +39,27 @@ extern "C" {
// user-code should use only these functions
//
GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_DEPRECATED(
GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
"obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
#ifdef __cplusplus
}

View File

@ -0,0 +1,26 @@
#ifndef GGML_OPENCL_H
#define GGML_OPENCL_H
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
//
// backend API
//
GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void);
GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void);
#ifdef __cplusplus
}
#endif
#endif // GGML_OPENCL_H

216
ggml/include/ggml-opt.h Normal file
View File

@ -0,0 +1,216 @@
// This file contains functionality for training models using GGML.
// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code.
//
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_opt_dataset;
struct ggml_opt_context;
struct ggml_opt_result;
typedef struct ggml_opt_dataset * ggml_opt_dataset_t;
typedef struct ggml_opt_context * ggml_opt_context_t;
typedef struct ggml_opt_result * ggml_opt_result_t;
// ====== Loss ======
// built-in loss types, i.e. the built-in quantities minimized by the optimizer
// custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
enum ggml_opt_loss_type {
GGML_OPT_LOSS_TYPE_MEAN,
GGML_OPT_LOSS_TYPE_SUM,
GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
};
// ====== Dataset ======
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
int64_t ne_datapoint, // number of elements per datapoint
int64_t ne_label, // number of elements per label
int64_t ndata, // total number of datapoints/labels
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
// get underlying tensors that store the data
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
// shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);
// get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
GGML_API void ggml_opt_dataset_get_batch(
ggml_opt_dataset_t dataset,
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
int64_t ibatch);
// ====== Model / Context ======
enum ggml_opt_build_type {
GGML_OPT_BUILD_TYPE_FORWARD,
GGML_OPT_BUILD_TYPE_GRAD,
GGML_OPT_BUILD_TYPE_OPT,
};
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
struct ggml_opt_optimizer_params {
// AdamW optimizer parameters
struct {
float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float wd; // weight decay for AdamW, use 0.0f to disable
} adamw;
};
// callback to calculate optimizer parameters prior to a backward pass
// userdata can be used to pass arbitrary data
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
// returns the default optimizer params (constant)
// userdata is not used
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
// parameters for initializing a new optimization context
struct ggml_opt_params {
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
// the forward graph is defined by inputs and outputs
// those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
struct ggml_tensor * inputs;
struct ggml_tensor * outputs;
enum ggml_opt_loss_type loss_type;
enum ggml_opt_build_type build_type;
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
// get parameters for an optimization context with defaults set where possible
// parameters for which no sensible defaults exist are supplied as arguments to this function
GGML_API ggml_opt_params ggml_opt_default_params(
ggml_backend_sched_t backend_sched,
struct ggml_context * ctx_compute,
struct ggml_tensor * inputs,
struct ggml_tensor * outputs,
enum ggml_opt_loss_type loss_type);
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
// set gradients to zero, initilize loss, and optionally reset the optimizer
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
// get underlying tensors that store data
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
// ====== Optimization Result ======
GGML_API ggml_opt_result_t ggml_opt_result_init();
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
// get data from result, uncertainties are optional and can be ignored by passing NULL
GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints
GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value
GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values
GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value
// ====== Computation ======
// do forward pass, increment result if not NULL
GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// do forward pass, increment result if not NULL, do backward pass
GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
// ############################################################################
// ## The high-level functions start here. They do not depend on any private ##
// ## functions or structs and can be copied to and adapted for user code. ##
// ############################################################################
// ====== Intended Usage ======
//
// 1. Select the appropriate loss for your problem.
// 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
// Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
// 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
// The first context should contain the model parameters and inputs and be allocated statically in user code.
// The second context should contain all other tensors and will be (re)allocated automatically.
// Due to this automated allocation the data of the second context is not defined when accessed in user code.
// Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
// 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
// signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
typedef void (*ggml_opt_epoch_callback)(
bool train, // true after training evaluation, false after validation evaluation
ggml_opt_context_t opt_ctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result, // result associated with the dataset subsection
int64_t ibatch, // number of batches that have been evaluated so far
int64_t ibatch_max, // total number of batches in this dataset subsection
int64_t t_start_us); // time at which the evaluation on the dataset subsection was started
// do training on front of dataset, do evaluation only on back of dataset
GGML_API void ggml_opt_epoch(
ggml_opt_context_t opt_ctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train, // result to increment during training, ignored if NULL
ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL
int64_t idata_split, // data index at which to split training and evaluation
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
// callback that prints a progress bar on stderr
GGML_API void ggml_opt_epoch_callback_progress_bar(
bool train,
ggml_opt_context_t opt_ctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result,
int64_t ibatch,
int64_t ibatch_max,
int64_t t_start_us);
// fit model defined by inputs and outputs to dataset
GGML_API void ggml_opt_fit(
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
int64_t nepoch, // how many times the dataset should be iterated over
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
bool silent); // whether or not info prints to stderr should be suppressed
#ifdef __cplusplus
}
#endif

Some files were not shown because too many files have changed in this diff Show More