mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-25 01:19:10 +00:00
Compare commits
62 Commits
grammar-de
...
metal-and-
Author | SHA1 | Date | |
---|---|---|---|
3ac0558009 | |||
a1664574fe | |||
bfcb2a2ab9 | |||
0d5e4cdc36 | |||
2b4160af29 | |||
f36554382a | |||
c46167f8c5 | |||
af947cb72e | |||
e81c67a125 | |||
f408c64564 | |||
d863f725a1 | |||
d37f56e7a9 | |||
23277d21ce | |||
ecb23fb1eb | |||
8e8daa8451 | |||
16db4da3f1 | |||
257d7942af | |||
181bb8cb28 | |||
796f84cd95 | |||
77f4bf49c8 | |||
b6f09669a2 | |||
254b687239 | |||
b19888cfb4 | |||
905c944143 | |||
3074a7ff14 | |||
ec9a7db74c | |||
cd476375b4 | |||
9fdd415367 | |||
79a88057bd | |||
fbc9ddc582 | |||
3b9979a373 | |||
de94c783ee | |||
3fec2119e6 | |||
d3b2dd4955 | |||
4845b9ed09 | |||
2770d46ef5 | |||
4d9acc60c3 | |||
06d1d2836b | |||
9a78b72246 | |||
794e8fe0ea | |||
fa672b46e6 | |||
af6f67b251 | |||
bed5ad69dd | |||
949ab6328d | |||
fbc3f8033e | |||
9b14418863 | |||
6ddc727fac | |||
acb5278cc8 | |||
0839209cab | |||
b39809668a | |||
3e9edc6845 | |||
bfc73f1fa2 | |||
f00c9bba33 | |||
b55b505690 | |||
2818de21ff | |||
aed5d40607 | |||
afa5477d1c | |||
01fcd42431 | |||
f990610776 | |||
64cb45fd79 | |||
ace6c12ec6 | |||
cac75be05b |
6
.github/workflows/build.yml
vendored
6
.github/workflows/build.yml
vendored
@ -428,15 +428,15 @@ jobs:
|
||||
|
||||
- name: Publish package
|
||||
if: ${{ github.ref == 'refs/heads/master' }}
|
||||
uses: gradle/gradle-build-action@v2
|
||||
uses: gradle/gradle-build-action@v2.4.2
|
||||
with:
|
||||
arguments: publish
|
||||
build-root-directory: bindings/java
|
||||
env:
|
||||
MAVEN_USERNAME: ${{ secrets.JIRA_USER }}
|
||||
MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }}
|
||||
# MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }}
|
||||
# MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }}
|
||||
PGP_SECRET: ${{ secrets.GPG_PRIVATE_KEY }}
|
||||
PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }}
|
||||
|
||||
quantize:
|
||||
runs-on: ubuntu-latest
|
||||
|
113
CMakeLists.txt
113
CMakeLists.txt
@ -1,4 +1,4 @@
|
||||
cmake_minimum_required (VERSION 3.0)
|
||||
cmake_minimum_required (VERSION 3.5)
|
||||
|
||||
project(whisper.cpp VERSION 1.4.2)
|
||||
|
||||
@ -35,6 +35,12 @@ endif()
|
||||
|
||||
# options
|
||||
|
||||
if (APPLE)
|
||||
set(WHISPER_METAL_DEFAULT ON)
|
||||
else()
|
||||
set(WHISPER_METAL_DEFAULT OFF)
|
||||
endif()
|
||||
|
||||
option(BUILD_SHARED_LIBS "whisper: build shared libs" ${BUILD_SHARED_LIBS_DEFAULT})
|
||||
|
||||
option(WHISPER_ALL_WARNINGS "whisper: enable all compiler warnings" ON)
|
||||
@ -58,6 +64,8 @@ option(WHISPER_OPENVINO "whisper: support for OpenVINO" OFF)
|
||||
|
||||
if (APPLE)
|
||||
option(WHISPER_NO_ACCELERATE "whisper: disable Accelerate framework" OFF)
|
||||
option(WHISPER_METAL "whisper: use Metal" ${WHISPER_METAL_DEFAULT})
|
||||
option(WHISPER_METAL_NDEBUG "whisper: disable Metal debugging" OFF)
|
||||
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
|
||||
option(WHISPER_COREML_ALLOW_FALLBACK "whisper: allow non-CoreML fallback" OFF)
|
||||
else()
|
||||
@ -113,6 +121,34 @@ if (APPLE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (WHISPER_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
|
||||
if (METAL_FRAMEWORK)
|
||||
message(STATUS "Metal framework found")
|
||||
|
||||
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS}
|
||||
${FOUNDATION_LIBRARY}
|
||||
${METAL_FRAMEWORK}
|
||||
${METALKIT_FRAMEWORK}
|
||||
)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_METAL)
|
||||
|
||||
if (WHISPER_METAL_NDEBUG)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_METAL_NDEBUG)
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "Metal framework not found")
|
||||
endif()
|
||||
|
||||
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
|
||||
|
||||
# copy ggml-metal.metal to bin directory
|
||||
configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY)
|
||||
endif()
|
||||
|
||||
if (WHISPER_COREML)
|
||||
find_library(FOUNDATION_FRAMEWORK Foundation)
|
||||
find_library(COREML_FRAMEWORK CoreML)
|
||||
@ -177,7 +213,7 @@ if (WHISPER_CUBLAS)
|
||||
|
||||
enable_language(CUDA)
|
||||
|
||||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
||||
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
|
||||
@ -228,7 +264,7 @@ if (WHISPER_CLBLAST)
|
||||
if (CLBlast_FOUND)
|
||||
message(STATUS "CLBlast found")
|
||||
|
||||
set(GGML_OPENCL_SOURCES ggml-opencl.cpp ggml-opencl.h)
|
||||
set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CLBLAST)
|
||||
|
||||
@ -321,6 +357,53 @@ else()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# POSIX conformance
|
||||
#
|
||||
|
||||
# clock_gettime came in POSIX.1b (1993)
|
||||
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
|
||||
# posix_memalign came in POSIX.1-2001 / SUSv3
|
||||
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
|
||||
add_compile_definitions(_XOPEN_SOURCE=600)
|
||||
|
||||
# Somehow in OpenBSD whenever POSIX conformance is specified
|
||||
# some string functions rely on locale_t availability,
|
||||
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||
remove_definitions(-D_XOPEN_SOURCE=600)
|
||||
add_compile_definitions(_XOPEN_SOURCE=700)
|
||||
endif()
|
||||
|
||||
# Data types, macros and functions related to controlling CPU affinity
|
||||
# are available on Linux through GNU extensions in libc
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
add_compile_definitions(_GNU_SOURCE)
|
||||
endif()
|
||||
|
||||
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
|
||||
# and on macOS its availability depends on enabling Darwin extensions
|
||||
# similarly on DragonFly, enabling BSD extensions is necessary
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||
add_compile_definitions(_DARWIN_C_SOURCE)
|
||||
endif()
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "DragonFly")
|
||||
add_compile_definitions(_DARWIN_C_SOURCE)
|
||||
endif()
|
||||
|
||||
# alloca is a non-standard interface that is not visible on BSDs when
|
||||
# POSIX conformance is specified, but not all of them provide a clean way
|
||||
# to enable it in such cases
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
|
||||
add_compile_definitions(__BSD_VISIBLE)
|
||||
endif()
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
|
||||
add_compile_definitions(_NETBSD_SOURCE)
|
||||
endif()
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||
add_compile_definitions(_BSD_SOURCE)
|
||||
endif()
|
||||
|
||||
if (WHISPER_PERF)
|
||||
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
|
||||
endif()
|
||||
@ -379,8 +462,11 @@ set(TARGET whisper)
|
||||
add_library(${TARGET}
|
||||
ggml.h
|
||||
ggml.c
|
||||
${GGML_CUDA_SOURCES}
|
||||
${GGML_OPENCL_SOURCES}
|
||||
ggml-alloc.h
|
||||
ggml-alloc.c
|
||||
${GGML_SOURCES_METAL}
|
||||
${GGML_SOURCES_CUDA}
|
||||
${GGML_SOURCES_OPENCL}
|
||||
whisper.h
|
||||
whisper.cpp
|
||||
)
|
||||
@ -421,9 +507,15 @@ if (BUILD_SHARED_LIBS)
|
||||
WHISPER_BUILD
|
||||
GGML_BUILD
|
||||
)
|
||||
|
||||
if (WHISPER_METAL)
|
||||
# TODO: I think this should make ggml-metal.m "see" the ggml-metal.metal file from the "bin" directory
|
||||
# but for some reason it does not work here like it does in llama.cpp
|
||||
set_target_properties(${TARGET} PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_SOURCES)
|
||||
if (GGML_SOURCES_CUDA)
|
||||
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
|
||||
set_property(TARGET whisper PROPERTY CUDA_ARCHITECTURES OFF)
|
||||
set_property(TARGET whisper PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
|
||||
@ -439,10 +531,13 @@ target_compile_definitions(${TARGET} PUBLIC
|
||||
|
||||
set_target_properties(${TARGET} PROPERTIES PUBLIC_HEADER "whisper.h")
|
||||
|
||||
include(GNUInstallDirs)
|
||||
|
||||
install(TARGETS ${TARGET}
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib/static
|
||||
RUNTIME DESTINATION bin
|
||||
RESOURCE DESTINATION bin
|
||||
PUBLIC_HEADER DESTINATION include
|
||||
)
|
||||
|
||||
|
74
Makefile
74
Makefile
@ -18,7 +18,7 @@ ifndef NVCC_VERSION
|
||||
endif
|
||||
endif
|
||||
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CCV := $(shell $(CC) --version | head -n 1)
|
||||
CXXV := $(shell $(CXX) --version | head -n 1)
|
||||
|
||||
# Mac OS + Arm can report x86_64
|
||||
@ -42,18 +42,55 @@ CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
|
||||
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
|
||||
LDFLAGS =
|
||||
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
|
||||
ifneq ($(wildcard /usr/include/musl/*),)
|
||||
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
|
||||
# clock_gettime came in POSIX.1b (1993)
|
||||
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
|
||||
# posix_memalign came in POSIX.1-2001 / SUSv3
|
||||
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
|
||||
CFLAGS += -D_XOPEN_SOURCE=600
|
||||
CXXFLAGS += -D_XOPEN_SOURCE=600
|
||||
|
||||
# Somehow in OpenBSD whenever POSIX conformance is specified
|
||||
# some string functions rely on locale_t availability,
|
||||
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||
ifeq ($(UNAME_S),OpenBSD)
|
||||
CFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
|
||||
CXXFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
|
||||
endif
|
||||
|
||||
# Data types, macros and functions related to controlling CPU affinity
|
||||
# are available on Linux through GNU extensions in libc
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
CFLAGS += -D_GNU_SOURCE
|
||||
CXXFLAGS += -D_GNU_SOURCE
|
||||
endif
|
||||
|
||||
# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
|
||||
# and on macOS its availability depends on enabling Darwin extensions
|
||||
# similarly on DragonFly, enabling BSD extensions is necessary
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CFLAGS += -D_DARWIN_C_SOURCE
|
||||
CXXFLAGS += -D_DARWIN_C_SOURCE
|
||||
endif
|
||||
ifeq ($(UNAME_S),DragonFly)
|
||||
CFLAGS += -D__BSD_VISIBLE
|
||||
CXXFLAGS += -D__BSD_VISIBLE
|
||||
endif
|
||||
|
||||
# alloca is a non-standard interface that is not visible on BSDs when
|
||||
# POSIX conformance is specified, but not all of them provide a clean way
|
||||
# to enable it in such cases
|
||||
ifeq ($(UNAME_S),FreeBSD)
|
||||
CFLAGS += -D__BSD_VISIBLE
|
||||
CXXFLAGS += -D__BSD_VISIBLE
|
||||
endif
|
||||
ifeq ($(UNAME_S),NetBSD)
|
||||
CFLAGS += -D_NETBSD_SOURCE
|
||||
CXXFLAGS += -D_NETBSD_SOURCE
|
||||
endif
|
||||
ifeq ($(UNAME_S),OpenBSD)
|
||||
CFLAGS += -D_BSD_SOURCE
|
||||
CXXFLAGS += -D_BSD_SOURCE
|
||||
endif
|
||||
|
||||
# OS specific
|
||||
# TODO: support Windows
|
||||
@ -67,7 +104,7 @@ endif
|
||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
||||
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
CPUINFO_CMD := sysctl machdep.cpu.features
|
||||
CPUINFO_CMD := sysctl machdep.cpu.features machdep.cpu.leaf7_features
|
||||
else ifeq ($(UNAME_S),Linux)
|
||||
CPUINFO_CMD := cat /proc/cpuinfo
|
||||
else ifneq (,$(filter MINGW32_NT% MINGW64_NT%,$(UNAME_S)))
|
||||
@ -145,6 +182,15 @@ ifdef WHISPER_COREML_ALLOW_FALLBACK
|
||||
endif
|
||||
endif
|
||||
|
||||
ifndef WHISPER_NO_METAL
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
WHISPER_METAL := 1
|
||||
|
||||
CXXFLAGS += -DGGML_USE_METAL
|
||||
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||
endif
|
||||
endif
|
||||
|
||||
ifdef WHISPER_OPENBLAS
|
||||
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas
|
||||
LDFLAGS += -lopenblas
|
||||
@ -251,6 +297,11 @@ $(info )
|
||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml-alloc.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
@ -266,6 +317,13 @@ whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-imp
|
||||
WHISPER_OBJ += whisper.o whisper-encoder.o whisper-encoder-impl.o
|
||||
endif
|
||||
|
||||
ifdef WHISPER_METAL
|
||||
ggml-metal.o: ggml-metal.m ggml-metal.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml-metal.o
|
||||
endif
|
||||
|
||||
libwhisper.a: ggml.o $(WHISPER_OBJ)
|
||||
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
|
||||
|
||||
@ -297,8 +355,8 @@ quantize: examples/quantize/quantize.cpp ggml.o $(WHISPER_OBJ) $(SRC_COMMON)
|
||||
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
command: examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp examples/grammar-parser.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
|
||||
|
||||
lsp: examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
|
||||
$(CXX) $(CXXFLAGS) examples/lsp/lsp.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o lsp $(CC_SDL) $(LDFLAGS)
|
||||
|
@ -11,14 +11,14 @@ Beta: [v1.4.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.4.2) / S
|
||||
High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model:
|
||||
|
||||
- Plain C/C++ implementation without dependencies
|
||||
- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate framework and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
||||
- Apple Silicon first-class citizen - optimized via ARM NEON, Accelerate framework, Metal and [Core ML](https://github.com/ggerganov/whisper.cpp#core-ml-support)
|
||||
- AVX intrinsics support for x86 architectures
|
||||
- VSX intrinsics support for POWER architectures
|
||||
- Mixed F16 / F32 precision
|
||||
- [4-bit and 5-bit integer quantization support](https://github.com/ggerganov/whisper.cpp#quantization)
|
||||
- Low memory usage (Flash Attention)
|
||||
- Zero memory allocations at runtime
|
||||
- Runs on the CPU
|
||||
- Support for CPU-only inference
|
||||
- [Partial GPU support for NVIDIA via cuBLAS](https://github.com/ggerganov/whisper.cpp#nvidia-gpu-support-via-cublas)
|
||||
- [Partial OpenCL GPU support via CLBlast](https://github.com/ggerganov/whisper.cpp#opencl-gpu-support-via-clblast)
|
||||
- [BLAS CPU support via OpenBLAS](https://github.com/ggerganov/whisper.cpp#blas-cpu-support-via-openblas)
|
||||
@ -50,6 +50,10 @@ You can also easily make your own offline voice assistant application: [command]
|
||||
|
||||
https://user-images.githubusercontent.com/1991296/204038393-2f846eae-c255-4099-a76d-5735c25c49da.mp4
|
||||
|
||||
On Apply Silicon, the inference runs fully on the GPU via Metal:
|
||||
|
||||
https://github.com/ggerganov/whisper.cpp/assets/1991296/c82e8f86-60dc-49f2-b048-d2fdbd6b5225
|
||||
|
||||
Or you can even run it straight in the browser: [talk.wasm](examples/talk.wasm)
|
||||
|
||||
## Implementation details
|
||||
|
Submodule bindings/ios updated: de46d9e781...22a9eef021
@ -2,6 +2,7 @@ plugins {
|
||||
id 'java'
|
||||
id 'java-library'
|
||||
id 'maven-publish'
|
||||
id 'signing'
|
||||
}
|
||||
|
||||
archivesBaseName = 'whispercpp'
|
||||
@ -109,4 +110,23 @@ publishing {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
repositories {
|
||||
maven {
|
||||
def releasesRepoUrl = 'https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/'
|
||||
def snapshotsRepoUrl = 'https://s01.oss.sonatype.org/content/repositories/snapshots/'
|
||||
url = version.endsWith('-SNAPSHOT') ? snapshotsRepoUrl : releasesRepoUrl
|
||||
credentials {
|
||||
username = System.getenv("MAVEN_USERNAME")
|
||||
password = System.getenv("MAVEN_PASSWORD")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
signing {
|
||||
def signingKey = System.getenv("PGP_SECRET")
|
||||
def signingPassword = System.getenv("PGP_PASSPHRASE")
|
||||
useInMemoryPgpKeys(signingKey, signingPassword)
|
||||
sign publishing.publications.mavenJava
|
||||
}
|
||||
|
@ -22,7 +22,13 @@ struct whisper_coreml_context * whisper_coreml_init(const char * path_model) {
|
||||
|
||||
NSURL * url_model = [NSURL fileURLWithPath: path_model_str];
|
||||
|
||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
|
||||
// select which device to run the Core ML model on
|
||||
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
|
||||
config.computeUnits = MLComputeUnitsCPUAndGPU;
|
||||
//config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
|
||||
//config.computeUnits = MLComputeUnitsAll;
|
||||
|
||||
const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);
|
||||
|
||||
if (data == NULL) {
|
||||
return NULL;
|
||||
|
@ -23,7 +23,6 @@ add_library(${TARGET} STATIC
|
||||
common.cpp
|
||||
common-ggml.h
|
||||
common-ggml.cpp
|
||||
grammar-parser.cpp
|
||||
)
|
||||
|
||||
include(DefaultTargetOptions)
|
||||
|
@ -44,13 +44,13 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
|
||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
||||
fprintf(stderr, " %-7s 0 - whisper encoder\n", "");
|
||||
fprintf(stderr, " %-7s 0 - whisper\n", "");
|
||||
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
||||
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
int whisper_bench_encoder(const whisper_params & params) {
|
||||
int whisper_bench_full(const whisper_params & params) {
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
@ -69,12 +69,49 @@ int whisper_bench_encoder(const whisper_params & params) {
|
||||
fprintf(stderr, "error: failed to set mel: %d\n", ret);
|
||||
return 3;
|
||||
}
|
||||
|
||||
// heat encoder
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
whisper_token tokens[512];
|
||||
memset(tokens, 0, sizeof(tokens));
|
||||
|
||||
// prompt heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
// text-generation heat
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
whisper_reset_timings(ctx);
|
||||
|
||||
// actual run
|
||||
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < 256; i++) {
|
||||
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
|
||||
fprintf(stderr, "error: failed to encode model: %d\n", ret);
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
whisper_free(ctx);
|
||||
|
||||
@ -103,7 +140,7 @@ int main(int argc, char ** argv) {
|
||||
int ret = -1;
|
||||
|
||||
switch (params.what) {
|
||||
case 0: ret = whisper_bench_encoder(params); break;
|
||||
case 0: ret = whisper_bench_full(params); break;
|
||||
case 1: ret = whisper_bench_memcpy(params.n_threads); break;
|
||||
case 2: ret = whisper_bench_ggml_mul_mat(params.n_threads); break;
|
||||
default: fprintf(stderr, "error: unknown benchmark: %d\n", params.what); break;
|
||||
|
@ -6,10 +6,9 @@
|
||||
// ref: https://github.com/ggerganov/whisper.cpp/issues/171
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <cassert>
|
||||
@ -22,11 +21,6 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -36,12 +30,8 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
@ -54,8 +44,6 @@ struct whisper_params {
|
||||
std::string fname_out;
|
||||
std::string commands;
|
||||
std::string prompt;
|
||||
std::string context;
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
||||
@ -85,9 +73,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
||||
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
|
||||
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -121,30 +106,16 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
||||
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
|
||||
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
|
||||
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & logprob_min,
|
||||
float & logprob_sum,
|
||||
int & n_tokens,
|
||||
int64_t & t_ms) {
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
logprob_min = 0.0f;
|
||||
logprob_sum = 0.0f;
|
||||
n_tokens = 0;
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
wparams.print_special = params.print_special;
|
||||
@ -152,37 +123,19 @@ std::string transcribe(
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.translate = params.translate;
|
||||
wparams.no_context = true;
|
||||
wparams.no_timestamps = params.no_timestamps;
|
||||
wparams.single_segment = true;
|
||||
wparams.max_tokens = params.max_tokens;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.temperature = 0.4f;
|
||||
wparams.temperature_inc = 1.0f;
|
||||
wparams.greedy.best_of = 5;
|
||||
|
||||
wparams.beam_search.beam_size = 5;
|
||||
|
||||
wparams.initial_prompt = params.context.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
int prob_n = 0;
|
||||
std::string result;
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
@ -191,17 +144,19 @@ std::string transcribe(
|
||||
|
||||
result += text;
|
||||
|
||||
const int n = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const int n_tokens = whisper_full_n_tokens(ctx, i);
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
const auto token = whisper_full_get_token_data(ctx, i, j);
|
||||
|
||||
if(token.plog > 0.0f) exit(0);
|
||||
logprob_min = std::min(logprob_min, token.plog);
|
||||
logprob_sum += token.plog;
|
||||
++n_tokens;
|
||||
prob += token.p;
|
||||
++prob_n;
|
||||
}
|
||||
}
|
||||
|
||||
if (prob_n > 0) {
|
||||
prob /= prob_n;
|
||||
}
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
||||
|
||||
@ -292,7 +247,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -460,9 +415,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
bool is_running = true;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
int n_tokens = 0;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
|
||||
@ -500,7 +453,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -536,27 +489,18 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
|
||||
float logprob_min0 = 0.0f;
|
||||
float logprob_min = 0.0f;
|
||||
|
||||
float logprob_sum0 = 0.0f;
|
||||
float logprob_sum = 0.0f;
|
||||
|
||||
int n_tokens0 = 0;
|
||||
int n_tokens = 0;
|
||||
float prob0 = 0.0f;
|
||||
float prob = 0.0f;
|
||||
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -589,11 +533,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min0);
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
@ -614,30 +556,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
||||
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (size_t n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= txt.size()) {
|
||||
break;
|
||||
}
|
||||
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
const auto prompt = txt.substr(0, n);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
@ -650,16 +581,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
// cut the prompt from the decoded text
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
}
|
||||
const std::string command = ::trim(txt.substr(best_len));
|
||||
|
||||
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
||||
fprintf(stdout, "\n");
|
||||
}
|
||||
|
||||
@ -724,36 +648,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
}
|
||||
|
||||
audio.pause();
|
||||
|
@ -792,7 +792,7 @@ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
|
||||
void sam_print_usage(int /*argc*/, char ** argv, const sam_params & params) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
|
@ -1,423 +0,0 @@
|
||||
#include "grammar-parser.h"
|
||||
#include <cstdint>
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <stdexcept>
|
||||
#include <exception>
|
||||
|
||||
namespace grammar_parser {
|
||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
||||
// copied from whisper.cpp
|
||||
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
|
||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t first_byte = static_cast<uint8_t>(*src);
|
||||
uint8_t highbits = first_byte >> 4;
|
||||
int len = lookup[highbits];
|
||||
uint8_t mask = (1 << (8 - len)) - 1;
|
||||
uint32_t value = first_byte & mask;
|
||||
const char * end = src + len; // may overrun!
|
||||
const char * pos = src + 1;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
|
||||
return result.first->second;
|
||||
}
|
||||
|
||||
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
|
||||
uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
|
||||
state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||
return next_id;
|
||||
}
|
||||
|
||||
void add_rule(
|
||||
parse_state & state,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule) {
|
||||
if (state.rules.size() <= rule_id) {
|
||||
state.rules.resize(rule_id + 1);
|
||||
}
|
||||
state.rules[rule_id] = rule;
|
||||
}
|
||||
|
||||
bool is_word_char(char c) {
|
||||
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
const char * pos = src;
|
||||
const char * end = src + size;
|
||||
uint32_t value = 0;
|
||||
for ( ; pos < end && *pos; pos++) {
|
||||
value <<= 4;
|
||||
char c = *pos;
|
||||
if ('a' <= c && c <= 'f') {
|
||||
value += c - 'a' + 10;
|
||||
} else if ('A' <= c && c <= 'F') {
|
||||
value += c - 'A' + 10;
|
||||
} else if ('0' <= c && c <= '9') {
|
||||
value += c - '0';
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pos != end) {
|
||||
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||
}
|
||||
return std::make_pair(value, pos);
|
||||
}
|
||||
|
||||
const char * parse_space(const char * src, bool newline_ok) {
|
||||
const char * pos = src;
|
||||
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||
if (*pos == '#') {
|
||||
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||
pos++;
|
||||
}
|
||||
} else {
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_name(const char * src) {
|
||||
const char * pos = src;
|
||||
while (is_word_char(*pos)) {
|
||||
pos++;
|
||||
}
|
||||
if (pos == src) {
|
||||
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||
if (*src == '\\') {
|
||||
switch (src[1]) {
|
||||
case 'x': return parse_hex(src + 2, 2);
|
||||
case 'u': return parse_hex(src + 2, 4);
|
||||
case 'U': return parse_hex(src + 2, 8);
|
||||
case 't': return std::make_pair('\t', src + 2);
|
||||
case 'r': return std::make_pair('\r', src + 2);
|
||||
case 'n': return std::make_pair('\n', src + 2);
|
||||
case '\\':
|
||||
case '"':
|
||||
case '[':
|
||||
case ']':
|
||||
return std::make_pair(src[1], src + 2);
|
||||
default:
|
||||
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||
}
|
||||
} else if (*src) {
|
||||
return decode_utf8(src);
|
||||
}
|
||||
throw std::runtime_error("unexpected end of input");
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested);
|
||||
|
||||
const char * parse_sequence(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
std::vector<whisper_grammar_element> & out_elements,
|
||||
bool is_nested) {
|
||||
size_t last_sym_start = out_elements.size();
|
||||
const char * pos = src;
|
||||
while (*pos) {
|
||||
if (*pos == '"') { // literal string
|
||||
pos++;
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != '"') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR, char_pair.first});
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '[') { // char range(s)
|
||||
pos++;
|
||||
enum whisper_gretype start_type = WHISPER_GRETYPE_CHAR;
|
||||
if (*pos == '^') {
|
||||
pos++;
|
||||
start_type = WHISPER_GRETYPE_CHAR_NOT;
|
||||
}
|
||||
last_sym_start = out_elements.size();
|
||||
while (*pos != ']') {
|
||||
auto char_pair = parse_char(pos);
|
||||
pos = char_pair.second;
|
||||
enum whisper_gretype type = last_sym_start < out_elements.size()
|
||||
? WHISPER_GRETYPE_CHAR_ALT
|
||||
: start_type;
|
||||
|
||||
out_elements.push_back({type, char_pair.first});
|
||||
if (pos[0] == '-' && pos[1] != ']') {
|
||||
auto endchar_pair = parse_char(pos + 1);
|
||||
pos = endchar_pair.second;
|
||||
out_elements.push_back({WHISPER_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||
}
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (is_word_char(*pos)) { // rule reference
|
||||
const char * name_end = parse_name(pos);
|
||||
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
|
||||
pos = parse_space(name_end, is_nested);
|
||||
last_sym_start = out_elements.size();
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, ref_rule_id});
|
||||
} else if (*pos == '(') { // grouping
|
||||
// parse nested alternates into synthesized rule
|
||||
pos = parse_space(pos + 1, true);
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
|
||||
last_sym_start = out_elements.size();
|
||||
// output reference to synthesized rule
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
if (*pos != ')') {
|
||||
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
|
||||
if (last_sym_start == out_elements.size()) {
|
||||
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
|
||||
}
|
||||
|
||||
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||
// rewrite rules:
|
||||
// S* --> S' ::= S S' |
|
||||
// S+ --> S' ::= S S' | S
|
||||
// S? --> S' ::= S |
|
||||
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
|
||||
std::vector<whisper_grammar_element> sub_rule;
|
||||
// add preceding symbol to generated rule
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
if (*pos == '*' || *pos == '+') {
|
||||
// cause generated rule to recurse
|
||||
sub_rule.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
}
|
||||
// mark start of alternate def
|
||||
sub_rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
if (*pos == '+') {
|
||||
// add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
sub_rule.insert(
|
||||
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
}
|
||||
sub_rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, sub_rule_id, sub_rule);
|
||||
|
||||
// in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start);
|
||||
out_elements.push_back({WHISPER_GRETYPE_RULE_REF, sub_rule_id});
|
||||
|
||||
pos = parse_space(pos + 1, is_nested);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_alternates(
|
||||
parse_state & state,
|
||||
const char * src,
|
||||
const std::string & rule_name,
|
||||
uint32_t rule_id,
|
||||
bool is_nested) {
|
||||
std::vector<whisper_grammar_element> rule;
|
||||
const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
|
||||
while (*pos == '|') {
|
||||
rule.push_back({WHISPER_GRETYPE_ALT, 0});
|
||||
pos = parse_space(pos + 1, true);
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested);
|
||||
}
|
||||
rule.push_back({WHISPER_GRETYPE_END, 0});
|
||||
add_rule(state, rule_id, rule);
|
||||
return pos;
|
||||
}
|
||||
|
||||
const char * parse_rule(parse_state & state, const char * src) {
|
||||
const char * name_end = parse_name(src);
|
||||
const char * pos = parse_space(name_end, false);
|
||||
size_t name_len = name_end - src;
|
||||
uint32_t rule_id = get_symbol_id(state, src, name_len);
|
||||
const std::string name(src, name_len);
|
||||
|
||||
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||
}
|
||||
pos = parse_space(pos + 3, true);
|
||||
|
||||
pos = parse_alternates(state, pos, name, rule_id, false);
|
||||
|
||||
if (*pos == '\r') {
|
||||
pos += pos[1] == '\n' ? 2 : 1;
|
||||
} else if (*pos == '\n') {
|
||||
pos++;
|
||||
} else if (*pos) {
|
||||
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||
}
|
||||
return parse_space(pos, true);
|
||||
}
|
||||
|
||||
parse_state parse(const char * src) {
|
||||
try {
|
||||
parse_state state;
|
||||
const char * pos = parse_space(src, true);
|
||||
while (*pos) {
|
||||
pos = parse_rule(state, pos);
|
||||
}
|
||||
return state;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||
return parse_state();
|
||||
}
|
||||
}
|
||||
|
||||
void print_grammar_char(FILE * file, uint32_t c) {
|
||||
if (0x20 <= c && c <= 0x7f) {
|
||||
fprintf(file, "%c", static_cast<char>(c));
|
||||
} else {
|
||||
// cop out of encoding UTF-8
|
||||
fprintf(file, "<U+%04X>", c);
|
||||
}
|
||||
}
|
||||
|
||||
bool is_char_element(whisper_grammar_element elem) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_CHAR: return true;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: return true;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
void print_rule_binary(FILE * file, const std::vector<whisper_grammar_element> & rule) {
|
||||
for (auto elem : rule) {
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END: fprintf(file, "END"); break;
|
||||
case WHISPER_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||
case WHISPER_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||
case WHISPER_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||
}
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "(%u) ", elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
fprintf(file, "(\"");
|
||||
print_grammar_char(file, elem.value);
|
||||
fprintf(file, "\") ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_rule(
|
||||
FILE * file,
|
||||
uint32_t rule_id,
|
||||
const std::vector<whisper_grammar_element> & rule,
|
||||
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||
if (rule.empty() || rule.back().type != WHISPER_GRETYPE_END) {
|
||||
throw std::runtime_error(
|
||||
"malformed rule, does not end with WHISPER_GRETYPE_END: " + std::to_string(rule_id));
|
||||
}
|
||||
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
whisper_grammar_element elem = rule[i];
|
||||
switch (elem.type) {
|
||||
case WHISPER_GRETYPE_END:
|
||||
throw std::runtime_error(
|
||||
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||
std::to_string(i));
|
||||
case WHISPER_GRETYPE_ALT:
|
||||
fprintf(file, "| ");
|
||||
break;
|
||||
case WHISPER_GRETYPE_RULE_REF:
|
||||
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR:
|
||||
fprintf(file, "[");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_NOT:
|
||||
fprintf(file, "[^");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
fprintf(file, "-");
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||
throw std::runtime_error(
|
||||
"WHISPER_GRETYPE_CHAR_ALT without preceding char: " +
|
||||
std::to_string(rule_id) + "," + std::to_string(i));
|
||||
}
|
||||
print_grammar_char(file, elem.value);
|
||||
break;
|
||||
}
|
||||
if (is_char_element(elem)) {
|
||||
switch (rule[i + 1].type) {
|
||||
case WHISPER_GRETYPE_CHAR_ALT:
|
||||
case WHISPER_GRETYPE_CHAR_RNG_UPPER:
|
||||
break;
|
||||
default:
|
||||
fprintf(file, "] ");
|
||||
}
|
||||
}
|
||||
}
|
||||
fprintf(file, "\n");
|
||||
}
|
||||
|
||||
void print_grammar(FILE * file, const parse_state & state) {
|
||||
try {
|
||||
std::map<uint32_t, std::string> symbol_id_names;
|
||||
for (auto kv : state.symbol_ids) {
|
||||
symbol_id_names[kv.second] = kv.first;
|
||||
}
|
||||
for (size_t i = 0, end = state.rules.size(); i < end; i++) {
|
||||
// fprintf(file, "%zu: ", i);
|
||||
// print_rule_binary(file, state.rules[i]);
|
||||
print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
|
||||
// fprintf(file, "\n");
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Implements a parser for an extended Backus-Naur form (BNF), producing the
|
||||
// binary context-free grammar format specified by whisper.h. Supports character
|
||||
// ranges, grouping, and repetition operators. As an example, a grammar for
|
||||
// arithmetic might look like:
|
||||
//
|
||||
// root ::= expr
|
||||
// expr ::= term ([-+*/] term)*
|
||||
// term ::= num | "(" space expr ")" space
|
||||
// num ::= [0-9]+ space
|
||||
// space ::= [ \t\n]*
|
||||
|
||||
#pragma once
|
||||
#include "whisper.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace grammar_parser {
|
||||
struct parse_state {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
void print_grammar(FILE * file, const parse_state & state);
|
||||
}
|
@ -324,12 +324,12 @@ json register_commandset(struct whisper_context * ctx, json jparams, std::vector
|
||||
commandset_list.push_back(cs);
|
||||
return json{{"index",index}};
|
||||
}
|
||||
json seek(struct whisper_context * ctx, audio_async &audio, json params) {
|
||||
json seek(struct whisper_context * /*ctx*/, audio_async & /*audio*/, json /*params*/) {
|
||||
// whisper_state has the pertinent offsets, but there also seem to be a large
|
||||
// number of scratch buffers that would prevent rewinding context in a manner similar to llama
|
||||
// I'll give this a another pass once everything else is implemented,
|
||||
// but for now, it's unsupported
|
||||
throw json{
|
||||
throw json {
|
||||
{"code", -32601},
|
||||
{"message", "Seeking is not yet supported."}
|
||||
};
|
||||
@ -412,7 +412,7 @@ void process_loop(struct whisper_context * ctx, audio_async &audio, const whispe
|
||||
jobqueue.pop_front();
|
||||
// send response
|
||||
std::string data = resp.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", data.length()+1, data.c_str());
|
||||
fprintf(stdout, "Content-Length: %d\r\n\r\n%s\n", (int)data.length()+1, data.c_str());
|
||||
std::cout.flush();
|
||||
|
||||
}
|
||||
|
@ -260,7 +260,7 @@ std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s
|
||||
|
||||
return speaker;
|
||||
}
|
||||
void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) {
|
||||
void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
|
||||
int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step;
|
||||
int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev);
|
||||
if (progress >= *progress_prev + progress_step) {
|
||||
@ -492,7 +492,7 @@ bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector<std::vector<float>> /*pcmf32s*/) {
|
||||
std::ofstream fout(fname);
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
|
@ -3,8 +3,8 @@
|
||||
// A very quick-n-dirty implementation serving mainly as a proof of concept.
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
|
||||
#include <cassert>
|
||||
|
@ -7,7 +7,7 @@ if (WHISPER_SDL2)
|
||||
|
||||
# TODO: this is temporary
|
||||
# need to export ggml symbols for MSVC, but too lazy ..
|
||||
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
|
||||
add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../ggml-alloc.c ../../whisper.cpp)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
|
||||
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
@ -1,11 +1,3 @@
|
||||
// Defines fileno on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#endif
|
||||
|
||||
#include "llama-util.h"
|
||||
#include "llama.h"
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "llama.h"
|
||||
|
||||
@ -649,7 +649,10 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, "\"", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: failed to speak\n", __func__);
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
|
||||
|
@ -191,9 +191,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
|
@ -1,8 +1,8 @@
|
||||
// Talk with AI
|
||||
//
|
||||
|
||||
#include "common.h"
|
||||
#include "common-sdl.h"
|
||||
#include "common.h"
|
||||
#include "whisper.h"
|
||||
#include "gpt-2.h"
|
||||
|
||||
@ -349,7 +349,10 @@ int main(int argc, char ** argv) {
|
||||
gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
|
||||
|
||||
text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
|
||||
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
int ret = system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
|
||||
if (ret != 0) {
|
||||
fprintf(stderr, "%s: system() failed!\n", __func__);
|
||||
}
|
||||
|
||||
audio.clear();
|
||||
|
||||
|
2
examples/whisper.android/.idea/compiler.xml
generated
2
examples/whisper.android/.idea/compiler.xml
generated
@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="CompilerConfiguration">
|
||||
<bytecodeTargetLevel target="11" />
|
||||
<bytecodeTargetLevel target="17" />
|
||||
</component>
|
||||
</project>
|
4
examples/whisper.android/.idea/gradle.xml
generated
4
examples/whisper.android/.idea/gradle.xml
generated
@ -4,15 +4,15 @@
|
||||
<component name="GradleSettings">
|
||||
<option name="linkedExternalProjectsSettings">
|
||||
<GradleProjectSettings>
|
||||
<option name="testRunner" value="GRADLE" />
|
||||
<option name="distributionType" value="DEFAULT_WRAPPED" />
|
||||
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||
<option name="gradleJvm" value="#GRADLE_LOCAL_JAVA_HOME" />
|
||||
<option name="modules">
|
||||
<set>
|
||||
<option value="$PROJECT_DIR$" />
|
||||
<option value="$PROJECT_DIR$/app" />
|
||||
</set>
|
||||
</option>
|
||||
<option name="resolveExternalAnnotations" value="false" />
|
||||
</GradleProjectSettings>
|
||||
</option>
|
||||
</component>
|
||||
|
2
examples/whisper.android/.idea/misc.xml
generated
2
examples/whisper.android/.idea/misc.xml
generated
@ -1,7 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="true" project-jdk-name="Android Studio default JDK" project-jdk-type="JavaSDK">
|
||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_17" default="true" project-jdk-name="jbr-17" project-jdk-type="JavaSDK">
|
||||
<output url="file://$PROJECT_DIR$/build/classes" />
|
||||
</component>
|
||||
<component name="ProjectType">
|
||||
|
@ -5,12 +5,12 @@ plugins {
|
||||
|
||||
android {
|
||||
namespace 'com.whispercppdemo'
|
||||
compileSdk 33
|
||||
compileSdk 34
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.whispercppdemo"
|
||||
minSdk 26
|
||||
targetSdk 32
|
||||
targetSdk 34
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
@ -31,19 +31,19 @@ android {
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
sourceCompatibility JavaVersion.VERSION_17
|
||||
targetCompatibility JavaVersion.VERSION_17
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
jvmTarget = '17'
|
||||
}
|
||||
buildFeatures {
|
||||
compose true
|
||||
}
|
||||
composeOptions {
|
||||
kotlinCompilerExtensionVersion '1.3.1'
|
||||
kotlinCompilerExtensionVersion '1.5.0'
|
||||
}
|
||||
ndkVersion "25.1.8937393"
|
||||
ndkVersion "25.2.9519653"
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path = file("src/main/jni/whisper/CMakeLists.txt")
|
||||
@ -57,19 +57,19 @@ android {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation 'androidx.activity:activity-compose:1.6.1'
|
||||
implementation 'androidx.compose.material:material-icons-core:1.3.1'
|
||||
implementation 'androidx.compose.material3:material3:1.0.1'
|
||||
implementation "androidx.compose.ui:ui:1.3.2"
|
||||
implementation "androidx.compose.ui:ui-tooling-preview:1.3.2"
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.5.1'
|
||||
implementation 'androidx.activity:activity-compose:1.7.2'
|
||||
implementation 'androidx.compose.material:material-icons-core:1.5.0'
|
||||
implementation 'androidx.compose.material3:material3:1.1.1'
|
||||
implementation "androidx.compose.ui:ui:1.5.0"
|
||||
implementation "androidx.compose.ui:ui-tooling-preview:1.5.0"
|
||||
implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1'
|
||||
implementation "com.google.accompanist:accompanist-permissions:0.28.0"
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.6.4'
|
||||
implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-core:1.7.2'
|
||||
|
||||
testImplementation 'junit:junit:4.13.2'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-tooling:1.3.2"
|
||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.3.2"
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
|
||||
androidTestImplementation "androidx.compose.ui:ui-test-junit4:1.5.0"
|
||||
debugImplementation "androidx.compose.ui:ui-tooling:1.5.0"
|
||||
debugImplementation "androidx.compose.ui:ui-test-manifest:1.5.0"
|
||||
}
|
@ -66,7 +66,7 @@ private fun MainScreen(
|
||||
|
||||
@Composable
|
||||
private fun MessageLog(log: String) {
|
||||
SelectionContainer() {
|
||||
SelectionContainer {
|
||||
Text(modifier = Modifier.verticalScroll(rememberScrollState()), text = log)
|
||||
}
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
|
||||
}
|
||||
|
||||
private suspend fun printSystemInfo() {
|
||||
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()));
|
||||
printMessage(String.format("System Info: %s\n", WhisperContext.getSystemInfo()))
|
||||
}
|
||||
|
||||
private suspend fun loadData() {
|
||||
|
@ -13,7 +13,7 @@ import androidx.compose.runtime.SideEffect
|
||||
import androidx.compose.ui.graphics.toArgb
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.platform.LocalView
|
||||
import androidx.core.view.ViewCompat
|
||||
import androidx.core.view.WindowCompat
|
||||
|
||||
private val DarkColorScheme = darkColorScheme(
|
||||
primary = Purple80,
|
||||
@ -55,8 +55,9 @@ fun WhisperCppDemoTheme(
|
||||
val view = LocalView.current
|
||||
if (!view.isInEditMode) {
|
||||
SideEffect {
|
||||
(view.context as Activity).window.statusBarColor = colorScheme.primary.toArgb()
|
||||
ViewCompat.getWindowInsetsController(view)?.isAppearanceLightStatusBars = darkTheme
|
||||
val window = (view.context as Activity).window
|
||||
window.statusBarColor = colorScheme.primary.toArgb()
|
||||
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,9 @@ class WhisperContext private constructor(private var ptr: Long) {
|
||||
|
||||
suspend fun transcribeData(data: FloatArray): String = withContext(scope.coroutineContext) {
|
||||
require(ptr != 0L)
|
||||
WhisperLib.fullTranscribe(ptr, data)
|
||||
val numThreads = WhisperCpuConfig.preferredThreadCount
|
||||
Log.d(LOG_TAG, "Selecting $numThreads threads")
|
||||
WhisperLib.fullTranscribe(ptr, numThreads, data)
|
||||
val textCount = WhisperLib.getTextSegmentCount(ptr)
|
||||
return@withContext buildString {
|
||||
for (i in 0 until textCount) {
|
||||
@ -126,7 +128,7 @@ private class WhisperLib {
|
||||
external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
|
||||
external fun initContext(modelPath: String): Long
|
||||
external fun freeContext(contextPtr: Long)
|
||||
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
|
||||
external fun fullTranscribe(contextPtr: Long, numThreads: Int, audioData: FloatArray)
|
||||
external fun getTextSegmentCount(contextPtr: Long): Int
|
||||
external fun getTextSegment(contextPtr: Long, index: Int): String
|
||||
external fun getSystemInfo(): String
|
||||
|
@ -0,0 +1,73 @@
|
||||
package com.whispercppdemo.whisper
|
||||
|
||||
import android.util.Log
|
||||
import java.io.BufferedReader
|
||||
import java.io.FileReader
|
||||
|
||||
object WhisperCpuConfig {
|
||||
val preferredThreadCount: Int
|
||||
// Always use at least 2 threads:
|
||||
get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
|
||||
}
|
||||
|
||||
private class CpuInfo(private val lines: List<String>) {
|
||||
private fun getHighPerfCpuCount(): Int = try {
|
||||
getHighPerfCpuCountByFrequencies()
|
||||
} catch (e: Exception) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
|
||||
getHighPerfCpuCountByVariant()
|
||||
}
|
||||
|
||||
private fun getHighPerfCpuCountByFrequencies(): Int =
|
||||
getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
|
||||
.also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
|
||||
.countDroppingMin()
|
||||
|
||||
private fun getHighPerfCpuCountByVariant(): Int =
|
||||
getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
|
||||
.also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
|
||||
.countKeepingMin()
|
||||
|
||||
private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
|
||||
|
||||
private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
|
||||
.asSequence()
|
||||
.filter { it.startsWith(property) }
|
||||
.map { mapper(it.substringAfter(':').trim()) }
|
||||
.sorted()
|
||||
.toList()
|
||||
|
||||
|
||||
private fun List<Int>.countDroppingMin(): Int {
|
||||
val min = min()
|
||||
return count { it > min }
|
||||
}
|
||||
|
||||
private fun List<Int>.countKeepingMin(): Int {
|
||||
val min = min()
|
||||
return count { it == min }
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val LOG_TAG = "WhisperCpuConfig"
|
||||
|
||||
fun getHighPerfCpuCount(): Int = try {
|
||||
readCpuInfo().getHighPerfCpuCount()
|
||||
} catch (e: Exception) {
|
||||
Log.d(LOG_TAG, "Couldn't read CPU info", e)
|
||||
// Our best guess -- just return the # of CPUs minus 4.
|
||||
(Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0)
|
||||
}
|
||||
|
||||
private fun readCpuInfo() = CpuInfo(
|
||||
BufferedReader(FileReader("/proc/cpuinfo"))
|
||||
.useLines { it.toList() }
|
||||
)
|
||||
|
||||
private fun getMaxCpuFrequency(cpuIndex: Int): Int {
|
||||
val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq"
|
||||
val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() }
|
||||
return maxFreq.toInt()
|
||||
}
|
||||
}
|
||||
}
|
@ -8,6 +8,7 @@ set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../../../../../)
|
||||
set(
|
||||
SOURCE_FILES
|
||||
${WHISPER_LIB_DIR}/ggml.c
|
||||
${WHISPER_LIB_DIR}/ggml-alloc.c
|
||||
${WHISPER_LIB_DIR}/whisper.cpp
|
||||
${CMAKE_SOURCE_DIR}/jni.c
|
||||
)
|
||||
@ -20,7 +21,7 @@ function(build_library target_name)
|
||||
SHARED
|
||||
${SOURCE_FILES}
|
||||
)
|
||||
|
||||
|
||||
target_link_libraries(${target_name} ${LOG_LIB} android)
|
||||
|
||||
if (${target_name} STREQUAL "whisper_v8fp16_va")
|
||||
|
@ -163,16 +163,12 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_freeContext(
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jfloatArray audio_data) {
|
||||
JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, jfloatArray audio_data) {
|
||||
UNUSED(thiz);
|
||||
struct whisper_context *context = (struct whisper_context *) context_ptr;
|
||||
jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL);
|
||||
const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data);
|
||||
|
||||
// Leave 2 processors free (i.e. the high-efficiency cores).
|
||||
int max_threads = max(1, min(8, get_nprocs() - 2));
|
||||
LOGI("Selecting %d threads", max_threads);
|
||||
|
||||
// The below adapted from the Objective-C iOS sample
|
||||
struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
params.print_realtime = true;
|
||||
@ -181,7 +177,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_fullTranscribe(
|
||||
params.print_special = false;
|
||||
params.translate = false;
|
||||
params.language = "en";
|
||||
params.n_threads = max_threads;
|
||||
params.n_threads = num_threads;
|
||||
params.offset_ms = 0;
|
||||
params.no_context = true;
|
||||
params.single_segment = false;
|
||||
|
@ -1,10 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
@ -1,6 +1,6 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
plugins {
|
||||
id 'com.android.application' version '7.3.1' apply false
|
||||
id 'com.android.library' version '7.3.1' apply false
|
||||
id 'org.jetbrains.kotlin.android' version '1.7.10' apply false
|
||||
id 'com.android.application' version '8.1.1' apply false
|
||||
id 'com.android.library' version '8.1.1' apply false
|
||||
id 'org.jetbrains.kotlin.android' version '1.9.0' apply false
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
#Wed Dec 14 10:37:24 EST 2022
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
|
||||
distributionPath=wrapper/dists
|
||||
zipStorePath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
|
@ -28,6 +28,8 @@ This can significantly improve the performance of the transcription:
|
||||
|
||||
<img width="1072" alt="image" src="https://user-images.githubusercontent.com/1991296/208511239-8d7cdbd1-aa48-41b5-becd-ca288d53cc07.png">
|
||||
|
||||
## Core ML
|
||||
|
||||
If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK` compiler flag for `whisper.cpp` in Build Phases:
|
||||
|
||||
<img width="1072" alt="image" src="https://github.com/ggerganov/whisper.cpp/assets/3001525/103e8f57-6eb6-490d-a60c-f6cf6c319324">
|
||||
@ -35,3 +37,13 @@ If you want to enable Core ML support, you can add the `-DWHISPER_USE_COREML -DW
|
||||
Then follow the [`Core ML support` section of readme](../../README.md#core-ml-support) for convert the model.
|
||||
|
||||
In this project, it also added `-O3 -DNDEBUG` to `Other C Flags`, but adding flags to app proj is not ideal in real world (applies to all C/C++ files), consider splitting xcodeproj in workspace in your own project.
|
||||
|
||||
## Metal
|
||||
|
||||
You can also enable Metal to make the inference run on the GPU of your device. This might or might not be more efficient
|
||||
compared to Core ML depending on the model and device that you use.
|
||||
|
||||
To enable Metal, just add `-DGGML_USE_METAL` instead off the `-DWHISPER_USE_COREML` flag and you are ready.
|
||||
This will make both the Encoder and the Decoder run on the GPU.
|
||||
|
||||
If you want to run the Encoder with Core ML and the Decoder with Metal then simply add both `-DWHISPER_USE_COREML -DGGML_USE_METAL` flags. That's all!
|
||||
|
@ -7,6 +7,9 @@
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 184447182AB211A2007D6BFE /* ggml-alloc.c */; };
|
||||
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 1844471B2AB21655007D6BFE /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; };
|
||||
184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1844471D2AB2195F007D6BFE /* ggml-metal.metal */; };
|
||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7A29052BDF00BD2A04 /* AppDelegate.m */; };
|
||||
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C7D29052BDF00BD2A04 /* SceneDelegate.m */; };
|
||||
18627C8129052BDF00BD2A04 /* ViewController.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8029052BDF00BD2A04 /* ViewController.m */; };
|
||||
@ -14,7 +17,7 @@
|
||||
18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; };
|
||||
18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; };
|
||||
18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; };
|
||||
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK"; }; };
|
||||
18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; };
|
||||
18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; };
|
||||
18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; };
|
||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FE342452A0C3FA20015A058 /* whisper-encoder-impl.m */; };
|
||||
@ -23,7 +26,24 @@
|
||||
7FE3424F2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc in Resources */ = {isa = PBXBuildFile; fileRef = 7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXCopyFilesBuildPhase section */
|
||||
184447202AB21B25007D6BFE /* CopyFiles */ = {
|
||||
isa = PBXCopyFilesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
dstPath = "";
|
||||
dstSubfolderSpec = 7;
|
||||
files = (
|
||||
184447212AB21B43007D6BFE /* ggml-metal.metal in CopyFiles */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXCopyFilesBuildPhase section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
184447182AB211A2007D6BFE /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = "ggml-alloc.c"; path = "../../../ggml-alloc.c"; sourceTree = "<group>"; };
|
||||
184447192AB211A2007D6BFE /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = "ggml-alloc.h"; path = "../../../ggml-alloc.h"; sourceTree = "<group>"; };
|
||||
1844471B2AB21655007D6BFE /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; name = "ggml-metal.m"; path = "../../../ggml-metal.m"; sourceTree = "<group>"; };
|
||||
1844471D2AB2195F007D6BFE /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; name = "ggml-metal.metal"; path = "../../../ggml-metal.metal"; sourceTree = "<group>"; };
|
||||
18627C7629052BDF00BD2A04 /* whisper.objc.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = whisper.objc.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
18627C7929052BDF00BD2A04 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
|
||||
18627C7A29052BDF00BD2A04 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
|
||||
@ -80,6 +100,10 @@
|
||||
18627C7829052BDF00BD2A04 /* whisper.objc */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
1844471D2AB2195F007D6BFE /* ggml-metal.metal */,
|
||||
1844471B2AB21655007D6BFE /* ggml-metal.m */,
|
||||
184447182AB211A2007D6BFE /* ggml-alloc.c */,
|
||||
184447192AB211A2007D6BFE /* ggml-alloc.h */,
|
||||
7FE3424E2A0C418A0015A058 /* ggml-base.en-encoder.mlmodelc */,
|
||||
7FE342442A0C3FA20015A058 /* coreml */,
|
||||
18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */,
|
||||
@ -126,6 +150,7 @@
|
||||
18627C7229052BDF00BD2A04 /* Sources */,
|
||||
18627C7329052BDF00BD2A04 /* Frameworks */,
|
||||
18627C7429052BDF00BD2A04 /* Resources */,
|
||||
184447202AB21B25007D6BFE /* CopyFiles */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
@ -194,8 +219,10 @@
|
||||
18627C9629052C5800BD2A04 /* ggml.c in Sources */,
|
||||
18627C7B29052BDF00BD2A04 /* AppDelegate.m in Sources */,
|
||||
7FE3424D2A0C3FA20015A058 /* whisper-decoder-impl.m in Sources */,
|
||||
1844471A2AB211A2007D6BFE /* ggml-alloc.c in Sources */,
|
||||
18627C8C29052BE000BD2A04 /* main.m in Sources */,
|
||||
18627C7E29052BDF00BD2A04 /* SceneDelegate.m in Sources */,
|
||||
1844471C2AB21655007D6BFE /* ggml-metal.m in Sources */,
|
||||
7FE3424B2A0C3FA20015A058 /* whisper-encoder-impl.m in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
|
@ -20,6 +20,7 @@
|
||||
0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; };
|
||||
0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; };
|
||||
0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; };
|
||||
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
@ -41,6 +42,8 @@
|
||||
0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = "<group>"; };
|
||||
0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = "<group>"; };
|
||||
0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
|
||||
18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = "<group>"; };
|
||||
18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
@ -124,6 +127,8 @@
|
||||
0AAC5DC529539E89003032C3 /* whisper.cpp */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
18AED47F2AB21F2B009D854F /* ggml-alloc.c */,
|
||||
18AED4802AB21F2B009D854F /* ggml-alloc.h */,
|
||||
0AAC5DC929539EB0003032C3 /* ggml.c */,
|
||||
0AAC5DCA29539EB0003032C3 /* ggml.h */,
|
||||
0AAC5DC729539EB0003032C3 /* whisper.cpp */,
|
||||
@ -242,6 +247,7 @@
|
||||
0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */,
|
||||
0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */,
|
||||
0AA7514E2953D958001EE061 /* Recorder.swift in Sources */,
|
||||
18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
@ -369,7 +375,7 @@
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = 3TZ9BM962G;
|
||||
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
@ -410,7 +416,7 @@
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"whisper.swiftui.demo/Supporting files/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = 3TZ9BM962G;
|
||||
DEVELOPMENT_TEAM = P8JZH34X63;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
|
@ -44,27 +44,26 @@ if [ "$encoder_only" -eq 0 ]; then
|
||||
printf "\n"
|
||||
fi
|
||||
|
||||
printf "| CPU | OS | Config | Model | Th | Load | Enc. | Commit |\n"
|
||||
printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ------ |\n"
|
||||
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
|
||||
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
||||
|
||||
for model in "${models[@]}"; do
|
||||
# run once to heat-up the cache
|
||||
./bench -m ./models/ggml-$model.bin -t $n_threads 2>/dev/null 1>/dev/null
|
||||
|
||||
# actual run
|
||||
# store stderr output in a variable in order to parse it later
|
||||
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
|
||||
ret=$?
|
||||
|
||||
# parse the output:
|
||||
load_time=$(echo "$output" | grep "load time" | awk '{print $5}')
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $5}')
|
||||
encode_time=$(echo "$output" | grep "encode time" | awk '{print $11}')
|
||||
decode_time=$(echo "$output" | grep "decode time" | awk '{print $11}')
|
||||
prompt_time=$(echo "$output" | grep "prompt time" | awk '{print $11}')
|
||||
system_info=$(echo "$output" | grep "system_info")
|
||||
n_threads=$(echo "$output" | grep "system_info" | awk '{print $4}')
|
||||
|
||||
# floor to milliseconds
|
||||
load_time=${load_time%.*}
|
||||
encode_time=${encode_time%.*}
|
||||
#encode_time=${encode_time%.*}
|
||||
#decode_time=${decode_time%.*}
|
||||
#prompt_time=${prompt_time%.*}
|
||||
|
||||
config=""
|
||||
|
||||
@ -87,6 +86,6 @@ for model in "${models[@]}"; do
|
||||
commit=$(git rev-parse --short HEAD)
|
||||
|
||||
if [ $ret -eq 0 ]; then
|
||||
printf "| <todo> | <todo> | $config | $model | $n_threads | $load_time | $encode_time | $commit |\n"
|
||||
printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
|
||||
fi
|
||||
done
|
||||
|
@ -1,18 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
cp -rpv ../ggml/src/ggml.c ./ggml.c
|
||||
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
|
||||
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
|
||||
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
|
||||
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
|
||||
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
|
||||
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
|
||||
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
|
||||
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
|
||||
cp -rpv ../ggml/examples/common.h ./examples/common.h
|
||||
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
|
||||
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
|
||||
cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
|
||||
cp -rpv ../ggml/src/ggml.c ./ggml.c
|
||||
cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c
|
||||
cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h
|
||||
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu
|
||||
cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h
|
||||
cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp
|
||||
cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h
|
||||
cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m
|
||||
cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal
|
||||
cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h
|
||||
cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h
|
||||
cp -rpv ../ggml/examples/common.h ./examples/common.h
|
||||
cp -rpv ../ggml/examples/common.cpp ./examples/common.cpp
|
||||
cp -rpv ../ggml/examples/common-ggml.h ./examples/common-ggml.h
|
||||
cp -rpv ../ggml/examples/common-ggml.cpp ./examples/common-ggml.cpp
|
||||
|
||||
cp -rpv ../ggml/examples/whisper/whisper.h ./whisper.h
|
||||
cp -rpv ../ggml/examples/whisper/whisper.cpp ./whisper.cpp
|
||||
|
187
ggml-alloc.c
187
ggml-alloc.c
@ -6,6 +6,26 @@
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
#include <unistd.h>
|
||||
#if defined(_POSIX_MAPPED_FILES)
|
||||
#include <sys/types.h>
|
||||
#include <sys/mman.h>
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <memoryapi.h>
|
||||
#endif
|
||||
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
||||
@ -99,15 +119,28 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||
static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||
return ggml_nbytes(tensor);
|
||||
|
||||
UNUSED(alloc);
|
||||
}
|
||||
|
||||
// check if a tensor is allocated by this buffer
|
||||
static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
|
||||
void * ptr = tensor->data;
|
||||
return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size;
|
||||
}
|
||||
|
||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
|
||||
GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
|
||||
#endif
|
||||
size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
|
||||
size = aligned_offset(NULL, size, alloc->alignment);
|
||||
|
||||
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
|
||||
@ -131,14 +164,14 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
||||
if (best_fit_block == -1) {
|
||||
// the last block is our last resort
|
||||
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
||||
max_avail = MAX(max_avail, block->size);
|
||||
if (block->size >= size) {
|
||||
best_fit_block = alloc->n_free_blocks - 1;
|
||||
max_avail = MAX(max_avail, block->size);
|
||||
} else {
|
||||
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
|
||||
__func__, size, max_avail);
|
||||
GGML_ASSERT(!"not enough space in the buffer");
|
||||
return;
|
||||
return;
|
||||
}
|
||||
}
|
||||
struct free_block * block = &alloc->free_blocks[best_fit_block];
|
||||
@ -173,17 +206,17 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
|
||||
}
|
||||
|
||||
// this is a very naive implementation, but for our case the number of free blocks should be very small
|
||||
static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||
static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
||||
void * ptr = tensor->data;
|
||||
|
||||
if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
|
||||
if (ggml_allocr_is_own(alloc, tensor) == false) {
|
||||
// the tensor was not allocated in this buffer
|
||||
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
|
||||
// the easiest way to deal with this is just to ignore it
|
||||
return;
|
||||
}
|
||||
|
||||
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
||||
size_t size = ggml_allocr_get_alloc_size(alloc, tensor);
|
||||
size = aligned_offset(NULL, size, alloc->alignment);
|
||||
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
|
||||
|
||||
@ -277,17 +310,68 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment)
|
||||
return alloc;
|
||||
}
|
||||
|
||||
// address and size of the buffer when measuring
|
||||
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
|
||||
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
|
||||
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
|
||||
// OS specific functions to allocate and free uncommitted virtual memory
|
||||
static void * alloc_vmem(size_t size) {
|
||||
#if defined(_WIN32)
|
||||
return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS);
|
||||
#elif defined(_POSIX_MAPPED_FILES)
|
||||
void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0);
|
||||
if (ptr == MAP_FAILED) {
|
||||
return NULL;
|
||||
}
|
||||
return ptr;
|
||||
#else
|
||||
// use a fixed address for other platforms
|
||||
uintptr_t base_addr = (uintptr_t)-size - 0x100;
|
||||
return (void *)base_addr;
|
||||
#endif
|
||||
}
|
||||
|
||||
static void free_vmem(void * base_addr, size_t size) {
|
||||
#if defined(_WIN32)
|
||||
VirtualFree(base_addr, 0, MEM_RELEASE);
|
||||
UNUSED(size);
|
||||
#elif defined(_POSIX_MAPPED_FILES)
|
||||
munmap(base_addr, size);
|
||||
#else
|
||||
// nothing to do
|
||||
UNUSED(base_addr);
|
||||
UNUSED(size);
|
||||
#endif
|
||||
}
|
||||
|
||||
// allocate uncommitted virtual memory to measure the size of the graph
|
||||
static void alloc_measure_vmem(void ** base_addr, size_t * size) {
|
||||
// 128GB for 64-bit, 1GB for 32-bit
|
||||
*size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37;
|
||||
do {
|
||||
*base_addr = alloc_vmem(*size);
|
||||
if (*base_addr != NULL) {
|
||||
AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr);
|
||||
return;
|
||||
}
|
||||
// try again with half the size
|
||||
*size /= 2;
|
||||
} while (*size > 0);
|
||||
|
||||
GGML_ASSERT(!"failed to allocate virtual memory for measure buffer");
|
||||
}
|
||||
|
||||
static void free_measure_vmem(void * base_addr, size_t size) {
|
||||
free_vmem(base_addr, size);
|
||||
}
|
||||
|
||||
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
||||
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
||||
|
||||
void * base_addr;
|
||||
size_t size;
|
||||
|
||||
alloc_measure_vmem(&base_addr, &size);
|
||||
|
||||
*alloc = (struct ggml_allocr){
|
||||
/*.data = */ MEASURE_BASE_ADDR,
|
||||
/*.size = */ MEASURE_MAX_SIZE,
|
||||
/*.data = */ base_addr,
|
||||
/*.size = */ size,
|
||||
/*.alignment = */ alignment,
|
||||
/*.n_free_blocks = */ 0,
|
||||
/*.free_blocks = */ {{0}},
|
||||
@ -307,6 +391,9 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
||||
}
|
||||
|
||||
void ggml_allocr_free(struct ggml_allocr * alloc) {
|
||||
if (alloc->measure) {
|
||||
free_measure_vmem(alloc->data, alloc->size);
|
||||
}
|
||||
free(alloc);
|
||||
}
|
||||
|
||||
@ -316,11 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
|
||||
|
||||
//////////// compute graph allocator
|
||||
|
||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
||||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
||||
}
|
||||
|
||||
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
||||
if (a->type != b->type) {
|
||||
return false;
|
||||
@ -336,28 +418,6 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
|
||||
return true;
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_VIEW:
|
||||
return t->src[0];
|
||||
case GGML_OP_CPY:
|
||||
return t->src[1];
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
||||
struct ggml_tensor * parent = t;
|
||||
do {
|
||||
parent = get_view_parent(parent);
|
||||
} while (ggml_is_view(parent));
|
||||
return parent;
|
||||
}
|
||||
|
||||
static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
case GGML_OP_SCALE:
|
||||
@ -365,7 +425,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
@ -375,10 +434,8 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_ADD_REL_POS:
|
||||
return true;
|
||||
|
||||
default:
|
||||
@ -390,24 +447,8 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
||||
struct hash_node * ht = alloc->hash_table;
|
||||
if (node->data == NULL) {
|
||||
if (ggml_is_view(node)) {
|
||||
size_t offset;
|
||||
switch(node->op) {
|
||||
case GGML_OP_VIEW:
|
||||
memcpy(&offset, node->op_params, sizeof(size_t));
|
||||
node->data = (char *) node->src[0]->data + offset;
|
||||
break;
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
node->data = node->src[0]->data;
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
node->data = node->src[1]->data;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(!"unknown view op");
|
||||
break;
|
||||
}
|
||||
assert(node->view_src->data != NULL);
|
||||
node->data = (char *)node->view_src->data + node->view_offs;
|
||||
} else {
|
||||
// see if we can reuse a parent's buffer (inplace)
|
||||
if (ggml_op_can_inplace(node->op)) {
|
||||
@ -418,8 +459,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
||||
}
|
||||
|
||||
// if the node's data is external, then we cannot re-use it
|
||||
if ((char *) parent->data < (char *) alloc->data ||
|
||||
(char *) parent->data >= ((char *) alloc->data + alloc->size)) {
|
||||
if (ggml_allocr_is_own(alloc, parent) == false) {
|
||||
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
|
||||
continue;
|
||||
}
|
||||
@ -427,7 +467,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
||||
struct hash_node * p_hn = hash_get(ht, parent);
|
||||
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
|
||||
if (ggml_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = get_view_source(parent);
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
||||
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
||||
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
|
||||
@ -453,7 +493,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
|
||||
}
|
||||
}
|
||||
|
||||
static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
static size_t ggml_allocr_alloc_graph_tensors_n(
|
||||
struct ggml_allocr * alloc,
|
||||
struct ggml_cgraph ** graphs, int n_graphs,
|
||||
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
|
||||
@ -469,7 +509,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
struct ggml_tensor * node = gf->nodes[i];
|
||||
|
||||
if (ggml_is_view(node)) {
|
||||
struct ggml_tensor * view_src = get_view_source(node);
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
hash_get(ht, view_src)->n_views += 1;
|
||||
}
|
||||
|
||||
@ -531,11 +571,10 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
AT_PRINTF("\n");
|
||||
}
|
||||
|
||||
|
||||
// update parents
|
||||
// update immediately if there is no parse_seq
|
||||
// update only at barriers if there is parse_seq
|
||||
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
|
||||
if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
|
||||
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
|
||||
int update_end = alloc->parse_seq_len ? ind : ind + 1;
|
||||
for (int i = update_start; i < update_end; i++) {
|
||||
@ -554,17 +593,17 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
|
||||
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = get_view_source(parent);
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
||||
view_src_hn->n_views -= 1;
|
||||
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
|
||||
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
|
||||
ggml_allocator_free_tensor(alloc, view_src);
|
||||
ggml_allocr_free_tensor(alloc, view_src);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (parent->data != node->data) {
|
||||
ggml_allocator_free_tensor(alloc, parent);
|
||||
ggml_allocr_free_tensor(alloc, parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -581,7 +620,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
for (int i = 0; outputs[g][i] != NULL; i++) {
|
||||
struct ggml_tensor * output = outputs[g][i];
|
||||
AT_PRINTF("output: %s\n", output->name);
|
||||
ggml_allocator_free_tensor(alloc, output);
|
||||
ggml_allocr_free_tensor(alloc, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -590,5 +629,5 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
|
||||
}
|
||||
|
||||
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
|
||||
return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
|
||||
return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
|
||||
}
|
||||
|
32
ggml-cuda.cu
32
ggml-cuda.cu
@ -4086,7 +4086,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
|
||||
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
|
||||
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
|
||||
const int col = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const int half_n_dims = ncols/4;
|
||||
|
||||
@ -4098,8 +4099,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
||||
const int i = row*ncols + col;
|
||||
|
||||
const float col_theta_scale = powf(theta_scale, col);
|
||||
const float p = p0 + p_delta*(row/p_delta_rows);
|
||||
|
||||
const float theta = p*col_theta_scale;
|
||||
const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
|
||||
const float sin_theta = sinf(theta);
|
||||
const float cos_theta = cosf(theta);
|
||||
|
||||
@ -4109,7 +4111,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
|
||||
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
||||
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
const float block_theta = block_p*col_theta_scale;
|
||||
const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
|
||||
const float sin_block_theta = sinf(block_theta);
|
||||
const float cos_block_theta = cosf(block_theta);
|
||||
|
||||
@ -4984,12 +4986,13 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
|
||||
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
|
||||
}
|
||||
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
|
||||
GGML_ASSERT(nrows % 4 == 0);
|
||||
const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
|
||||
const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
|
||||
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
|
||||
const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 4 == 0);
|
||||
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
|
||||
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
|
||||
const dim3 block_nums(num_blocks_x, nrows, 1);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
|
||||
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
|
||||
}
|
||||
|
||||
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
|
||||
@ -5723,22 +5726,18 @@ inline void ggml_cuda_op_rope(
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
// compute
|
||||
if (is_glm) {
|
||||
const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
|
||||
const float id_p = min(p, n_ctx - 2.f);
|
||||
const float block_p = max(p - (n_ctx - 2.f), 0.f);
|
||||
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
|
||||
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
|
||||
} else if (is_neox) {
|
||||
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
|
||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
||||
} else {
|
||||
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
|
||||
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
|
||||
}
|
||||
|
||||
@ -6400,10 +6399,7 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
|
||||
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
|
||||
}
|
||||
|
||||
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
113
ggml-metal.m
113
ggml-metal.m
@ -63,7 +63,10 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(relu);
|
||||
GGML_METAL_DECL_KERNEL(gelu);
|
||||
GGML_METAL_DECL_KERNEL(soft_max);
|
||||
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_f32);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
||||
@ -77,6 +80,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -117,14 +121,17 @@ static NSString * const msl_library_source = @"see metal.metal";
|
||||
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
metal_printf("%s: allocating\n", __func__);
|
||||
|
||||
// Show all the Metal device instances in the system
|
||||
NSArray * devices = MTLCopyAllDevices();
|
||||
id <MTLDevice> device;
|
||||
NSString * s;
|
||||
|
||||
#if TARGET_OS_OSX
|
||||
// Show all the Metal device instances in the system
|
||||
NSArray * devices = MTLCopyAllDevices();
|
||||
for (device in devices) {
|
||||
s = [device name];
|
||||
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Pick and show default Metal device
|
||||
device = MTLCreateSystemDefaultDevice();
|
||||
@ -139,14 +146,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
ctx->n_buffers = 0;
|
||||
ctx->concur_list_len = 0;
|
||||
|
||||
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
||||
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
#if 0
|
||||
// compile from source string and show compile log
|
||||
#ifdef GGML_SWIFT
|
||||
// load the default.metallib file
|
||||
{
|
||||
NSError * error = nil;
|
||||
|
||||
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
|
||||
NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
|
||||
NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
|
||||
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
||||
|
||||
// Load the metallib file into a Metal library
|
||||
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
||||
|
||||
if (error) {
|
||||
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
@ -161,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
|
||||
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
||||
|
||||
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
||||
@ -207,7 +222,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(relu);
|
||||
GGML_METAL_ADD_KERNEL(gelu);
|
||||
GGML_METAL_ADD_KERNEL(soft_max);
|
||||
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_f32);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
||||
@ -221,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -247,13 +266,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
#undef GGML_METAL_ADD_KERNEL
|
||||
}
|
||||
|
||||
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
#if TARGET_OS_OSX
|
||||
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
if (ctx->device.maxTransferRate != 0) {
|
||||
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
} else {
|
||||
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
}
|
||||
#endif
|
||||
|
||||
return ctx;
|
||||
}
|
||||
@ -273,7 +294,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(relu);
|
||||
GGML_METAL_DEL_KERNEL(gelu);
|
||||
GGML_METAL_DEL_KERNEL(soft_max);
|
||||
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_f32);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
||||
@ -287,6 +311,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -327,7 +352,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
|
||||
void * ggml_metal_host_malloc(size_t n) {
|
||||
void * data = NULL;
|
||||
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
||||
const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
|
||||
if (result != 0) {
|
||||
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
||||
return NULL;
|
||||
@ -365,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
||||
for (int i = 0; i < ctx->n_buffers; ++i) {
|
||||
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
||||
|
||||
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
||||
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
||||
*offs = (size_t) ioffs;
|
||||
|
||||
@ -401,7 +427,7 @@ bool ggml_metal_add_buffer(
|
||||
}
|
||||
}
|
||||
|
||||
const size_t size_page = getpagesize();
|
||||
const size_t size_page = sysconf(_SC_PAGESIZE);
|
||||
|
||||
size_t size_aligned = size;
|
||||
if ((size_aligned % size_page) != 0) {
|
||||
@ -454,6 +480,7 @@ bool ggml_metal_add_buffer(
|
||||
}
|
||||
}
|
||||
|
||||
#if TARGET_OS_OSX
|
||||
metal_printf(", (%8.2f / %8.2f)",
|
||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
@ -463,6 +490,9 @@ bool ggml_metal_add_buffer(
|
||||
} else {
|
||||
metal_printf("\n");
|
||||
}
|
||||
#else
|
||||
metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
|
||||
#endif
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -698,6 +728,7 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
// utilize float4
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
@ -705,6 +736,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
||||
@ -721,6 +753,7 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
// utilize float4
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
@ -728,6 +761,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul];
|
||||
@ -743,6 +777,8 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const float scale = *(const float *) src1->data;
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||
@ -750,7 +786,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@ -762,7 +798,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@ -782,7 +818,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@ -796,13 +832,16 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
const int nth = 32;
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||
if (ne00%4 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||
}
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
@ -810,14 +849,23 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||
if (ne00%8 == 0) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||
}
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
if (ne00%8 == 0) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
}
|
||||
else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
@ -830,8 +878,8 @@ void ggml_metal_graph_compute(
|
||||
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
if (ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
if (!ggml_is_transposed(src0) &&
|
||||
!ggml_is_transposed(src1) &&
|
||||
src1t == GGML_TYPE_F32 &&
|
||||
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||
ne00%32 == 0 &&
|
||||
@ -856,14 +904,18 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
int nrows = 1;
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
@ -873,8 +925,14 @@ void ggml_metal_graph_compute(
|
||||
nth1 = 1;
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||
//} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
} else if (false) {
|
||||
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||
nrows = ne11;
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
nrows = 4;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -995,7 +1053,7 @@ void ggml_metal_graph_compute(
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
int64_t ny = (ne11 + 3)/4;
|
||||
int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
@ -1003,6 +1061,7 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
||||
@ -1018,9 +1077,9 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||
|
||||
const int64_t n = ggml_nelements(src1);
|
||||
|
||||
@ -1141,7 +1200,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
|
||||
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
|
544
ggml-metal.metal
544
ggml-metal.metal
@ -38,7 +38,7 @@ kernel void kernel_add_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb,
|
||||
constant int64_t & nb,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
@ -63,18 +63,18 @@ kernel void kernel_mul_row(
|
||||
}
|
||||
|
||||
kernel void kernel_scale(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
constant float & scale,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * scale;
|
||||
}
|
||||
|
||||
kernel void kernel_silu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
float x = src0[tpig];
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
kernel void kernel_gelu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
float x = src0[tpig];
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
@ -107,7 +107,6 @@ kernel void kernel_soft_max(
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
@ -119,64 +118,70 @@ kernel void kernel_soft_max(
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
// parallel max
|
||||
buf[tpitg[0]] = -INFINITY;
|
||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
|
||||
float lmax = psrc0[tpitg[0]];
|
||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
lmax = MAX(lmax, psrc0[i00]);
|
||||
}
|
||||
|
||||
// reduce
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
||||
if (tpitg[0] < i) {
|
||||
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
||||
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
||||
//if (tpitg[0] == 0) {
|
||||
// buf[0] = buf[0];
|
||||
//}
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float max = buf[0];
|
||||
const float max = simd_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
buf[tpitg[0]] = 0.0f;
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||
buf[tpitg[0]] += exp_psrc0;
|
||||
lsum += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// whish to compute it twice.
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
// reduce
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint i = ntg[0]/2; i > 0; i /= 2) {
|
||||
if (tpitg[0] < i) {
|
||||
buf[tpitg[0]] += buf[tpitg[0] + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// broadcast - not needed, see above
|
||||
//// broadcast
|
||||
//if (tpitg[0] == 0) {
|
||||
// buf[0] = buf[0];
|
||||
//}
|
||||
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float sum = buf[0];
|
||||
const float sum = simd_sum(lsum);
|
||||
|
||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
pdst[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_soft_max_4(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = psrc4[tpitg[0]];
|
||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||
}
|
||||
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
||||
const float max = simd_max(lmax);
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||
const float4 exp_psrc4 = exp(psrc4[i00] - max);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
|
||||
const float sum = simd_sum(lsum);
|
||||
|
||||
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||
pdst4[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_diag_mask_inf(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
||||
} else {
|
||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_diag_mask_inf_8(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int & n_past,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
|
||||
const int64_t i = 2*tpig[0];
|
||||
|
||||
dst[i+0] = src0[i+0];
|
||||
dst[i+1] = src0[i+1];
|
||||
int64_t i4 = 4*i;
|
||||
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
||||
const int64_t i00 = i4;
|
||||
for (int k = 3; k >= 0; --k) {
|
||||
if (i00 + 4 + k <= n_past + i01) {
|
||||
break;
|
||||
}
|
||||
dst[i+1][k] = -INFINITY;
|
||||
if (i00 + k > n_past + i01) {
|
||||
dst[i][k] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -220,14 +252,10 @@ kernel void kernel_norm(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
//// broadcast
|
||||
//if (tpitg == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
const float mean = sum[0];
|
||||
const float mean = sum[0] / ne00;
|
||||
|
||||
// recenter and VARIANCE
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
device float * y = dst + tgpig*ne00;
|
||||
sum[tpitg] = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
@ -235,12 +263,6 @@ kernel void kernel_norm(
|
||||
sum[tpitg] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
//// VARIANCE
|
||||
//// parallel sum
|
||||
//sum[tpitg] = 0.0f;
|
||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
// sum[tpitg] += y[i00] * y[i00];
|
||||
//}
|
||||
// reduce
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||
@ -249,12 +271,7 @@ kernel void kernel_norm(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
//// broadcast
|
||||
//if (tpitg == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
const float variance = sum[0];
|
||||
const float variance = sum[0] / ne00;
|
||||
|
||||
const float scale = 1.0f/sqrt(variance + eps);
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
@ -262,7 +279,6 @@ kernel void kernel_norm(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
kernel void kernel_rms_norm(
|
||||
device const void * src0,
|
||||
device float * dst,
|
||||
@ -630,7 +646,49 @@ kernel void kernel_mul_mat_f16_f32(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
kernel void kernel_mul_mat_f16_f32_l4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int nrows = ne11;
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
|
||||
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_alibi_f32(
|
||||
@ -699,25 +757,27 @@ kernel void kernel_rope(
|
||||
constant int & mode,
|
||||
constant float & freq_base,
|
||||
constant float & freq_scale,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
const int64_t i3 = tpig[2];
|
||||
const int64_t i2 = tpig[1];
|
||||
const int64_t i1 = tpig[0];
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int64_t i3 = tgpig[2];
|
||||
const int64_t i2 = tgpig[1];
|
||||
const int64_t i1 = tgpig[0];
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const float theta_scale = pow(freq_base, -2.0f/n_dims);
|
||||
|
||||
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
|
||||
|
||||
float theta = freq_scale * (float)p;
|
||||
const float theta_0 = freq_scale * (float)p;
|
||||
const float inv_ndims = -1.f/n_dims;
|
||||
|
||||
if (!is_neox) {
|
||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||
const float cos_theta = cos(theta);
|
||||
const float sin_theta = sin(theta);
|
||||
|
||||
theta *= theta_scale;
|
||||
|
||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
@ -729,12 +789,12 @@ kernel void kernel_rope(
|
||||
}
|
||||
} else {
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
||||
const float cos_theta = cos(theta);
|
||||
const float sin_theta = sin(theta);
|
||||
|
||||
theta *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
@ -1138,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
|
||||
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
|
||||
|
||||
float yl[16];
|
||||
float yl[32];
|
||||
|
||||
const uint16_t kmask1 = 0x0303;
|
||||
const uint16_t kmask1 = 0x3030;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
|
||||
const int tid = tiisg/2;
|
||||
const int ix = tiisg%2;
|
||||
const int ip = tid/8; // 0 or 1
|
||||
const int il = tid/2 - 4*ip; // 0...3
|
||||
const int tid = tiisg/4;
|
||||
const int ix = tiisg%4;
|
||||
const int ip = tid/4; // 0 or 1
|
||||
const int il = 2*((tid%4)/2); // 0 or 2
|
||||
const int ir = tid%2;
|
||||
const int n = 8;
|
||||
const int l0 = n*ir;
|
||||
|
||||
const uint16_t m1 = 1 << (4*ip + il);
|
||||
const uint16_t m2 = m1 << 8;
|
||||
// One would think that the Metal compiler would figure out that ip and il can only have
|
||||
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
|
||||
// with these two tales.
|
||||
//
|
||||
// Possible masks for the high bit
|
||||
const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
|
||||
{0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
|
||||
{0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
|
||||
{0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
|
||||
|
||||
// Possible masks for the low 2 bits
|
||||
const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
|
||||
|
||||
const ushort4 hm = mm[2*ip + il/2];
|
||||
|
||||
const int shift = 2*il;
|
||||
const uint16_t qm1 = 0x0003 << shift;
|
||||
const uint16_t qm2 = 0x0300 << shift;
|
||||
const int32_t v1 = 4 << shift;
|
||||
const int32_t v2 = 1024 << shift;
|
||||
const float v1 = il == 0 ? 4.f : 64.f;
|
||||
const float v2 = 4.f * v1;
|
||||
|
||||
const uint16_t s_shift1 = 4*ip;
|
||||
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
||||
const int ik = 4 + (il%2);
|
||||
const uint16_t s_shift2 = s_shift1 + il;
|
||||
|
||||
const int q_offset = 32*ip + l0;
|
||||
const int y_offset = 128*ip + 32*il + l0;
|
||||
@ -1171,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
|
||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||
|
||||
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
||||
for (int i = ix; i < nb; i += 2) {
|
||||
uint32_t scales32, aux32;
|
||||
thread uint16_t * scales16 = (thread uint16_t *)&scales32;
|
||||
thread const int8_t * scales = (thread const int8_t *)&scales32;
|
||||
|
||||
float sumf1[2] = {0.f};
|
||||
float sumf2[2] = {0.f};
|
||||
for (int i = ix; i < nb; i += 4) {
|
||||
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
yl[l+0] = y1[l+ 0];
|
||||
yl[l+8] = y1[l+16];
|
||||
yl[l+ 0] = y1[l+ 0];
|
||||
yl[l+ 8] = y1[l+16];
|
||||
yl[l+16] = y1[l+32];
|
||||
yl[l+24] = y1[l+48];
|
||||
}
|
||||
|
||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
||||
@ -1187,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
const float d_all = (float)dh[0];
|
||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
||||
|
||||
float s1 = 0, s2 = 0;
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const uint16_t qs = q[l/2];
|
||||
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
||||
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
||||
}
|
||||
float d = d_all * (s1 + 1.f/256.f * s2);
|
||||
sumf1[row] += d * scales[0];
|
||||
sumf2[row] += d;
|
||||
scales16[0] = a[4];
|
||||
scales16[1] = a[5];
|
||||
aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
|
||||
scales16[0] = a[il+0];
|
||||
scales16[1] = a[il+1];
|
||||
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
|
||||
|
||||
s1 = s2 = 0;
|
||||
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const uint16_t qs = q[l/2+8];
|
||||
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
||||
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
||||
const int32_t qs = q[l/2];
|
||||
s1 += yl[l+0] * (qs & qm[il/2][0]);
|
||||
s2 += yl[l+1] * (qs & qm[il/2][1]);
|
||||
s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
|
||||
s4 += yl[l+16] * (qs & qm[il/2][2]);
|
||||
s5 += yl[l+17] * (qs & qm[il/2][3]);
|
||||
s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
|
||||
}
|
||||
d = d_all * (s1 + 1.f/256.f * s2);
|
||||
sumf1[row] += d * scales[1];
|
||||
sumf2[row] += d;
|
||||
float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||
float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||
sumf1[row] += d1 * (scales[0] - 32);
|
||||
sumf2[row] += d2 * (scales[2] - 32);
|
||||
|
||||
s1 = s2 = s3 = s4 = s5 = s6 = 0;
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const int32_t qs = q[l/2+8];
|
||||
s1 += yl[l+8] * (qs & qm[il/2][0]);
|
||||
s2 += yl[l+9] * (qs & qm[il/2][1]);
|
||||
s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
|
||||
s4 += yl[l+24] * (qs & qm[il/2][2]);
|
||||
s5 += yl[l+25] * (qs & qm[il/2][3]);
|
||||
s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
|
||||
}
|
||||
d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
|
||||
d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
|
||||
sumf1[row] += d1 * (scales[1] - 32);
|
||||
sumf2[row] += d2 * (scales[3] - 32);
|
||||
|
||||
q += step;
|
||||
h += step;
|
||||
@ -1216,15 +1308,17 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||
|
||||
}
|
||||
|
||||
y1 += 2 * QK_K;
|
||||
y1 += 4 * QK_K;
|
||||
|
||||
}
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
||||
const float tot = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
|
||||
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
||||
sumf1[row] = simd_sum(sumf);
|
||||
}
|
||||
if (tiisg == 0) {
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1579,17 +1673,25 @@ kernel void kernel_mul_mat_q5_K_f32(
|
||||
sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
|
||||
sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
|
||||
|
||||
float4 acc = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc1 = {0.f};
|
||||
float4 acc2 = {0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
uint8_t h = qh[l];
|
||||
acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
|
||||
acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
|
||||
acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
|
||||
acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
|
||||
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
||||
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
||||
acc1[2] += yh[l+0] * (q2[l] & 0x0F);
|
||||
acc1[3] += yh[l+8] * (q2[l] & 0xF0);
|
||||
acc2[0] += h & hm1 ? yl[l+0] : 0.f;
|
||||
acc2[1] += h & hm2 ? yl[l+8] : 0.f;
|
||||
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
||||
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
||||
}
|
||||
const float dall = dh[0];
|
||||
const float dmin = dh[1];
|
||||
sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
|
||||
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
||||
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
||||
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
||||
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
||||
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
||||
|
||||
q1 += step;
|
||||
@ -1762,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
||||
|
||||
//============================= templates and their specializations =============================
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
float4x4 temp = *(((device float4x4 *)src));
|
||||
for (int i = 0; i < 16; i++){
|
||||
reg[i/4][i%4] = temp[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
half4x4 temp = *(((device half4x4 *)src));
|
||||
@ -1773,28 +1884,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
||||
const half m = il ? ( -8.h * 16.h) : -8.h;
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float md = -8.h * xb->d;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i=0;i<8;i++) {
|
||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
|
||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
|
||||
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
||||
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const half d = il ? (xb->d / 16.h) : xb->d;
|
||||
const half m = xb->m;
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float m = xb->m;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = il ? 0xF000 : 0x0F00;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i=0;i<8;i++) {
|
||||
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
|
||||
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
|
||||
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
|
||||
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1830,7 +1943,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d_all = (float)(xb->d);
|
||||
const half d_all = xb->d;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
@ -1843,16 +1956,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
||||
((il/4)>0 ? 12 : 3);
|
||||
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
||||
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
|
||||
(scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
||||
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
|
||||
const half ml = 4.h * dl;
|
||||
|
||||
il = (il/2)%4;
|
||||
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
il = (il/2) & 3;
|
||||
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
dl *= coef;
|
||||
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||
}
|
||||
#else
|
||||
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
||||
@ -1867,26 +1982,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
||||
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
||||
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * q = xb->qs;
|
||||
device const uchar * q = xb->qs;
|
||||
|
||||
#if QK_K == 256
|
||||
const float d = (float)(xb->d);
|
||||
const float min = (float)(xb->dmin);
|
||||
short is = (il/4) * 2;
|
||||
q = q + (il/4) * 32 + 16 * (il&1);
|
||||
il = il%4;
|
||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const half min = xb->dmin;
|
||||
const half dl = d * sc[0];
|
||||
const half ml = min * sc[1];
|
||||
#else
|
||||
q = q + 16 * (il&1);
|
||||
device const uint8_t * s = xb->scales;
|
||||
device const half2 * dh = (device const half2 *)xb->d;
|
||||
const float2 d = (float2)dh[0];
|
||||
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
|
||||
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
|
||||
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
|
||||
#endif
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
@ -1900,19 +2020,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
||||
device const uint8_t * qh = xb->qh;
|
||||
|
||||
#if QK_K == 256
|
||||
const float d = (float)(xb->d);
|
||||
const float min = (float)(xb->dmin);
|
||||
short is = (il/4) * 2;
|
||||
q = q + 32 * (il/4) + 16 * (il&1);
|
||||
qh = qh + 16 * (il&1);
|
||||
uint8_t ul = 1 << (il/2);
|
||||
il = il%4;
|
||||
const uchar4 sc = get_scale_min_k4(is, xb->scales);
|
||||
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
|
||||
const float ml = il<2 ? min * sc[1] : min * sc[3];
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const half d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const half min = xb->dmin;
|
||||
const half dl = d * sc[0];
|
||||
const half ml = min * sc[1];
|
||||
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const float qh_val = il<2 ? 16.f : 256.f;
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const half qh_val = il<2 ? 16.h : 256.h;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||
}
|
||||
@ -1931,7 +2051,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d_all = (float)(xb->d);
|
||||
const half d_all = xb->d;
|
||||
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
||||
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
@ -1939,19 +2059,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
|
||||
#if QK_K == 256
|
||||
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
||||
qh = qh + 32*(il/8) + 16*(il&1);
|
||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
||||
il = (il/2)%4;
|
||||
half sc = scales[(il%2) + 2 * ((il/2))];
|
||||
il = (il/2) & 3;
|
||||
#else
|
||||
ql = ql + 16 * (il&1);
|
||||
float sc = scales[il];
|
||||
half sc = scales[il];
|
||||
#endif
|
||||
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
||||
const half coef = il>1 ? 1.f/16.h : 1.h;
|
||||
const half ml = d_all * sc * 32.h;
|
||||
const half dl = d_all * sc * coef;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
||||
const float coef = il>1 ? 1.f/16.f : 1.f;
|
||||
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
|
||||
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
|
||||
reg[i/4][i%4] = d_all * sc * q * coef;
|
||||
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
||||
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
||||
reg[i/4][i%4] = dl * q - ml;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1991,22 +2113,25 @@ kernel void kernel_get_rows(
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = ((threadgroup half *)shared_memory);
|
||||
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
|
||||
const uint r0 = tgpig.y;
|
||||
@ -2019,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_float8x8 c_res[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
@ -2027,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
}
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
||||
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
|
||||
|
||||
uint offset0 = im/gqa*nb02;
|
||||
ushort offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||
device const float * y = (device const float *)(src1
|
||||
+ nb12 * im
|
||||
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
//load data and store to threadgroup memory
|
||||
@ -2110,6 +2240,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
||||
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||
@ -2120,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
|
||||
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
||||
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
||||
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
||||
typedef void (mat_mm_t)(
|
||||
device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar *, uint3, uint, uint);
|
||||
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
|
69
ggml.c
69
ggml.c
@ -1,4 +1,3 @@
|
||||
#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
|
||||
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
||||
|
||||
#include "ggml.h"
|
||||
@ -107,6 +106,9 @@ typedef void * thread_ret_t;
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
#include <hbwmalloc.h>
|
||||
#endif
|
||||
|
||||
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
|
||||
@ -196,9 +198,15 @@ typedef void * thread_ret_t;
|
||||
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
|
||||
#else
|
||||
inline static void * ggml_aligned_malloc(size_t size) {
|
||||
if (size == 0) {
|
||||
GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
|
||||
return NULL;
|
||||
}
|
||||
void * aligned_memory = NULL;
|
||||
#ifdef GGML_USE_METAL
|
||||
int result = posix_memalign(&aligned_memory, getpagesize(), size);
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
int result = hbw_posix_memalign(&aligned_memory, 16, size);
|
||||
#elif GGML_USE_METAL
|
||||
int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
|
||||
#else
|
||||
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
|
||||
#endif
|
||||
@ -219,8 +227,12 @@ inline static void * ggml_aligned_malloc(size_t size) {
|
||||
return aligned_memory;
|
||||
}
|
||||
#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr)
|
||||
#else
|
||||
#define GGML_ALIGNED_FREE(ptr) free(ptr)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define UNUSED GGML_UNUSED
|
||||
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
||||
@ -4291,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||
size_t nbytes;
|
||||
size_t blck_size = ggml_blck_size(tensor->type);
|
||||
if (blck_size == 1) {
|
||||
nbytes = ggml_type_size(tensor->type);
|
||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||
}
|
||||
}
|
||||
else {
|
||||
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||
}
|
||||
}
|
||||
|
||||
return nbytes;
|
||||
}
|
||||
|
||||
@ -4572,6 +4595,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// allow to call ggml_init with 0 size
|
||||
if (params.mem_size == 0) {
|
||||
params.mem_size = GGML_MEM_ALIGN;
|
||||
}
|
||||
|
||||
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
|
||||
|
||||
*ctx = (struct ggml_context) {
|
||||
@ -4774,7 +4802,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
|
||||
|
||||
size_t obj_alloc_size = 0;
|
||||
|
||||
if (view_src == NULL && ctx->no_alloc == false) {
|
||||
if (view_src == NULL && !ctx->no_alloc) {
|
||||
if (ctx->scratch.data != NULL) {
|
||||
// allocate tensor data in the scratch buffer
|
||||
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
|
||||
@ -5475,7 +5503,7 @@ static struct ggml_tensor * ggml_mul_impl(
|
||||
}
|
||||
|
||||
if (inplace) {
|
||||
GGML_ASSERT(is_node == false);
|
||||
GGML_ASSERT(!is_node);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
@ -5518,7 +5546,7 @@ static struct ggml_tensor * ggml_div_impl(
|
||||
}
|
||||
|
||||
if (inplace) {
|
||||
GGML_ASSERT(is_node == false);
|
||||
GGML_ASSERT(!is_node);
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
@ -17266,10 +17294,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
} else {
|
||||
// wait for other threads to finish
|
||||
const int last = node_n;
|
||||
do {
|
||||
//sched_yield();
|
||||
while (true) {
|
||||
// TODO: this sched_yield can have significant impact on the performance - either positive or negative
|
||||
// depending on the workload and the operating system.
|
||||
// since it is not clear what is the best approach, it should potentially become user-configurable
|
||||
// ref: https://github.com/ggerganov/ggml/issues/291
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
sched_yield();
|
||||
#endif
|
||||
|
||||
node_n = atomic_load(&state->shared->node_n);
|
||||
} while (node_n == last);
|
||||
if (node_n != last) break;
|
||||
};
|
||||
}
|
||||
|
||||
// check if we should stop
|
||||
@ -18320,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
struct ggml_tensor * node = cgraph->leafs[i];
|
||||
|
||||
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
|
||||
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
|
||||
i,
|
||||
node->ne[0], node->ne[1],
|
||||
ggml_op_name(node->op));
|
||||
ggml_op_name(node->op),
|
||||
ggml_get_name(node));
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_OP_COUNT; i++) {
|
||||
@ -19962,7 +19999,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
|
||||
struct ggml_tensor * data = NULL;
|
||||
|
||||
if (params.no_alloc == false) {
|
||||
if (!params.no_alloc) {
|
||||
data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
|
||||
|
||||
ok = ok && data != NULL;
|
||||
@ -20003,7 +20040,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
}
|
||||
|
||||
// point the data member to the appropriate location in the binary blob using the tensor infos
|
||||
if (params.no_alloc == false) {
|
||||
if (!params.no_alloc) {
|
||||
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
|
||||
cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
|
||||
}
|
||||
|
@ -1,57 +0,0 @@
|
||||
# - "turn on lights."
|
||||
# - "set thermostat to 22."
|
||||
# - "increase TV by 10."
|
||||
# - "decrease oven by 50."
|
||||
# - "play music."
|
||||
# - "stop podcast."
|
||||
# - "schedule cleaning at 3pm."
|
||||
# - "cancel cleaning."
|
||||
# - "remind me to buy milk at 5pm."
|
||||
# - "show me security system."
|
||||
# - "hide washing machine."
|
||||
# - "what is the lights status?"
|
||||
# - "what is the current thermostat value?"
|
||||
# - "what is the security system status?"
|
||||
# - "what is the door lock status?"
|
||||
# - "what is the camera battery level?"
|
||||
# - "what is the weather like today?"
|
||||
# - "what is the forecast for tomorrow?"
|
||||
# - "what is the time?"
|
||||
# - "what is my schedule for today?"
|
||||
# - "what tasks do I have?"
|
||||
# - "what reminders do I have?"
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
|
||||
#
|
||||
|
||||
root ::= init " " (command | question) "."
|
||||
prompt ::= init
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " Ok Whisper, start listening for commands."
|
||||
|
||||
command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
|
||||
"Increase " device " by " value | "Decrease " device " by " value |
|
||||
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
|
||||
"Remind me to " task " at " time | "Show me " device | "Hide " device
|
||||
|
||||
question ::= "What is the " device " status?" | "What is the current " device " value?" |
|
||||
"What is the " device " temperature?" | "What is the " device " humidity?" |
|
||||
"What is the " device " power consumption?" | "What is the " device " battery level?" |
|
||||
"What is the weather like today?" | "What is the forecast for tomorrow?" |
|
||||
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
|
||||
"What reminders do I have?"
|
||||
|
||||
device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
|
||||
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
|
||||
"vacuum cleaner"
|
||||
|
||||
value ::= [0-9]+
|
||||
|
||||
media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"
|
||||
|
||||
task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?
|
||||
|
||||
time ::= [0-9] [0-9]? ("am" | "pm")?
|
@ -1,29 +0,0 @@
|
||||
# - bishop to c3
|
||||
# - rook to d4
|
||||
# - knight to e5
|
||||
# - d4 d5 knight to c3
|
||||
# - c3 queen to d4 king b1
|
||||
# - pawn to a1 bishop to b2 knight to c3
|
||||
#
|
||||
# The prompt (--prompt) is the initial phrase that the user has to say.
|
||||
# This is used to prime Whisper with how the user is expected to speak.
|
||||
#
|
||||
# Provide long context (--context) with sample moves to help Whisper decode the correct sequence.
|
||||
# Longer context is better, but it slightly increases the processing time.
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100
|
||||
#
|
||||
|
||||
root ::= init move move? move? "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " rook to b4, f3"
|
||||
|
||||
move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
||||
|
||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
||||
king ::= "king"
|
||||
pawn ::= "pawn"
|
@ -1,16 +0,0 @@
|
||||
# - red
|
||||
# - green
|
||||
# - blue
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue,"
|
||||
#
|
||||
|
||||
root ::= init color "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " red, green, blue"
|
||||
|
||||
color ::= ", " ("red" | "green" | "blue")
|
@ -22,7 +22,28 @@ function get_script_path() {
|
||||
models_path="$(get_script_path)"
|
||||
|
||||
# Whisper models
|
||||
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" )
|
||||
models=(
|
||||
"tiny.en"
|
||||
"tiny"
|
||||
"tiny-q5_1"
|
||||
"tiny.en-q5_1"
|
||||
"base.en"
|
||||
"base"
|
||||
"base-q5_1"
|
||||
"base.en-q5_1"
|
||||
"small.en"
|
||||
"small.en-tdrz"
|
||||
"small"
|
||||
"small-q5_1"
|
||||
"small.en-q5_1"
|
||||
"medium"
|
||||
"medium.en"
|
||||
"medium-q5_0"
|
||||
"medium.en-q5_0"
|
||||
"large-v1"
|
||||
"large"
|
||||
"large-q5_0"
|
||||
)
|
||||
|
||||
# list available models
|
||||
function list_models {
|
||||
|
2815
whisper.cpp
2815
whisper.cpp
File diff suppressed because it is too large
Load Diff
37
whisper.h
37
whisper.h
@ -96,37 +96,6 @@ extern "C" {
|
||||
void (*close)(void * ctx);
|
||||
} whisper_model_loader;
|
||||
|
||||
// grammar element type
|
||||
enum whisper_gretype {
|
||||
// end of rule definition
|
||||
WHISPER_GRETYPE_END = 0,
|
||||
|
||||
// start of alternate definition for rule
|
||||
WHISPER_GRETYPE_ALT = 1,
|
||||
|
||||
// non-terminal element: reference to rule
|
||||
WHISPER_GRETYPE_RULE_REF = 2,
|
||||
|
||||
// terminal element: character (code point)
|
||||
WHISPER_GRETYPE_CHAR = 3,
|
||||
|
||||
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||
WHISPER_GRETYPE_CHAR_NOT = 4,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||
// be an inclusive range ([a-z])
|
||||
WHISPER_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||
|
||||
// modifies a preceding WHISPER_GRETYPE_CHAR or
|
||||
// WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||
WHISPER_GRETYPE_CHAR_ALT = 6,
|
||||
};
|
||||
|
||||
typedef struct whisper_grammar_element {
|
||||
enum whisper_gretype type;
|
||||
uint32_t value; // Unicode code point or rule ID
|
||||
} whisper_grammar_element;
|
||||
|
||||
// Various functions for loading a ggml whisper model.
|
||||
// Allocate (almost) all memory needed for the model.
|
||||
// Return NULL on failure
|
||||
@ -389,7 +358,6 @@ extern "C" {
|
||||
|
||||
bool translate;
|
||||
bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||
bool no_timestamps; // do not generate timestamps
|
||||
bool single_segment; // force single segment output (useful for streaming)
|
||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||
bool print_progress; // print progress information
|
||||
@ -463,11 +431,6 @@ extern "C" {
|
||||
// called by each decoder to filter obtained logits
|
||||
whisper_logits_filter_callback logits_filter_callback;
|
||||
void * logits_filter_callback_user_data;
|
||||
|
||||
const whisper_grammar_element ** grammar_rules;
|
||||
size_t n_grammar_rules;
|
||||
size_t i_start_rule;
|
||||
float grammar_penalty;
|
||||
};
|
||||
|
||||
// NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params()
|
||||
|
Reference in New Issue
Block a user