From 128694213f223ccc362ded8bb66e5e4cd65e40cc Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 16 Oct 2023 21:46:29 +0200 Subject: [PATCH] feat: llama.cpp gRPC C++ backend (#1170) * wip: llama.cpp c++ gRPC server Signed-off-by: Ettore Di Giacinto * make it work, attach it to the build process Signed-off-by: Ettore Di Giacinto * update deps Signed-off-by: Ettore Di Giacinto * fix: add protobuf dep Signed-off-by: Ettore Di Giacinto * try fix protobuf on cmake * cmake: workarounds Signed-off-by: Ettore Di Giacinto * add packages * cmake: use fixed version of grpc Signed-off-by: Ettore Di Giacinto * cmake(grpc): install locally * install grpc Signed-off-by: Ettore Di Giacinto * install required deps for grpc on debian bullseye Signed-off-by: Ettore Di Giacinto * debug * debug * Fixups * no need to install cmake manually Signed-off-by: Ettore Di Giacinto * ci: fixup macOS * use brew whenever possible Signed-off-by: Ettore Di Giacinto * macOS fixups * debug * fix container build Signed-off-by: Ettore Di Giacinto * workaround * try mac https://stackoverflow.com/questions/23905661/on-mac-g-clang-fails-to-search-usr-local-include-and-usr-local-lib-by-def * Disable temp. arm64 docker image builds --------- Signed-off-by: Ettore Di Giacinto --- .github/workflows/bump_deps.yaml | 3 + .github/workflows/image.yml | 5 +- .github/workflows/release.yaml | 14 + .github/workflows/test.yml | 20 +- Dockerfile | 17 +- Makefile | 26 +- backend/cpp/llama/CMakeLists.txt | 57 ++ backend/cpp/llama/Makefile | 44 ++ backend/cpp/llama/grpc-server.cpp | 964 ++++++++++++++++++++++++++++++ pkg/model/initializers.go | 11 +- 10 files changed, 1145 insertions(+), 16 deletions(-) create mode 100644 backend/cpp/llama/CMakeLists.txt create mode 100644 backend/cpp/llama/Makefile create mode 100644 backend/cpp/llama/grpc-server.cpp diff --git a/.github/workflows/bump_deps.yaml b/.github/workflows/bump_deps.yaml index 4344ac2b..f8fd93d8 100644 --- a/.github/workflows/bump_deps.yaml +++ b/.github/workflows/bump_deps.yaml @@ -12,6 +12,9 @@ jobs: - repository: "go-skynet/go-llama.cpp" variable: "GOLLAMA_VERSION" branch: "master" + - repository: "ggerganov/llama.cpp" + variable: "CPPLLAMA_VERSION" + branch: "master" - repository: "go-skynet/go-ggml-transformers.cpp" variable: "GOGGMLTRANSFORMERS_VERSION" branch: "master" diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index a9a97c0c..952b5492 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -19,7 +19,8 @@ jobs: matrix: include: - build-type: '' - platforms: 'linux/amd64,linux/arm64' + #platforms: 'linux/amd64,linux/arm64' + platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '' ffmpeg: '' @@ -38,7 +39,7 @@ jobs: tag-suffix: '-cublas-cuda12' ffmpeg: '' - build-type: '' - platforms: 'linux/amd64,linux/arm64' + platforms: 'linux/amd64' tag-latest: 'false' tag-suffix: '-ffmpeg' ffmpeg: 'true' diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4a0e83e0..5e472f96 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -29,6 +29,12 @@ jobs: run: | sudo apt-get update sudo apt-get install build-essential ffmpeg + + git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + ../.. && sudo make -j12 install + - name: Build id: build env: @@ -66,12 +72,20 @@ jobs: - uses: actions/setup-go@v4 with: go-version: '>=1.21.0' + - name: Dependencies + run: | + git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + ../.. && make -j12 install && rm -rf grpc - name: Build id: build env: CMAKE_ARGS: "${{ matrix.defines }}" BUILD_ID: "${{ matrix.build }}" run: | + export C_INCLUDE_PATH=/usr/local/include + export CPLUS_INCLUDE_PATH=/usr/local/include make dist - uses: actions/upload-artifact@v3 with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6baecd1a..605b24f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -67,11 +67,15 @@ jobs: run: | sudo apt-get update sudo apt-get install build-essential ffmpeg - + sudo apt-get install -y ca-certificates cmake curl patch sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 sudo pip install -r extra/requirements.txt + + # Pre-build stable diffusion before we install a newever version of abseil (not compatible with stablediffusion-ncn) + GO_TAGS="tts stablediffusion" GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build + sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \ tar -xzvf - && \ @@ -87,6 +91,12 @@ jobs: sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ sudo ln -s /usr/lib/libpiper_phonemize.so /usr/lib/libpiper_phonemize.so.1 && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/ + + git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + ../.. && sudo make -j12 install + - name: Test run: | ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test @@ -108,6 +118,14 @@ jobs: # You can test your matrix by printing the current Go version - name: Display Go version run: go version + - name: Dependencies + run: | + git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + ../.. && make -j12 install && rm -rf grpc - name: Test run: | + export C_INCLUDE_PATH=/usr/local/include + export CPLUS_INCLUDE_PATH=/usr/local/include CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" make test \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index b6316b3f..431307a5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,8 @@ ENV GALLERIES='[{"name":"model-gallery", "url":"github:go-skynet/model-gallery/i ARG GO_TAGS="stablediffusion tts" RUN apt-get update && \ - apt-get install -y ca-certificates cmake curl patch pip + apt-get install -y ca-certificates curl patch pip cmake + # Use the variables in subsequent instructions RUN echo "Target Architecture: $TARGETARCH" @@ -104,6 +105,15 @@ RUN make prepare COPY . . COPY .git . +# stablediffusion does not tolerate a newer version of abseil, build it first +RUN GRPC_BACKENDS=backend-assets/grpc/stablediffusion make build + +RUN git clone --recurse-submodules -b v1.58.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \ + cd grpc && mkdir -p cmake/build && cd cmake/build && cmake -DgRPC_INSTALL=ON \ + -DgRPC_BUILD_TESTS=OFF \ + ../.. && make -j12 install && rm -rf grpc + +# Rebuild with defaults backends RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build ################################### @@ -132,8 +142,13 @@ WORKDIR /build # https://github.com/go-skynet/LocalAI/pull/434 COPY . . RUN make prepare-sources + +# Copy the binary COPY --from=builder /build/local-ai ./ +# do not let piper rebuild (requires an older version of absl) +COPY --from=builder /build/backend-assets/grpc/piper ./backend-assets/grpc/piper + # Copy VALLE-X as it's not a real "lib" RUN cp -rfv /usr/lib/vall-e-x/* ./ diff --git a/Makefile b/Makefile index 8ad4c579..692417bd 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ GOLLAMA_VERSION?=1676dcd7a139b6cdfbaea5fd67f46dc25d9d8bcf GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7 +CPPLLAMA_VERSION?=24ba3d829e31a6eda3fa1723f692608c2fa3adda + # gpt4all version GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8 @@ -120,7 +122,7 @@ ifeq ($(findstring tts,$(GO_TAGS)),tts) OPTIONAL_GRPC+=backend-assets/grpc/piper endif -GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) +GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-cpp backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) .PHONY: all test build vendor @@ -223,7 +225,7 @@ go-llama/libbinding.a: go-llama go-llama-stable/libbinding.a: go-llama-stable $(MAKE) -C go-llama-stable BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a -go-piper/libpiper_binding.a: +go-piper/libpiper_binding.a: go-piper $(MAKE) -C go-piper libpiper_binding.a example/main get-sources: go-llama go-llama-stable go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion @@ -280,6 +282,7 @@ clean: ## Remove build related file rm -rf ./go-ggllm rm -rf $(BINARY_NAME) rm -rf release/ + $(MAKE) -C backend/cpp/llama clean ## Build: @@ -395,6 +398,16 @@ ifeq ($(BUILD_TYPE),metal) cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/ endif +backend/cpp/llama/grpc-server: + LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server + +backend-assets/grpc/llama-cpp: backend-assets/grpc backend/cpp/llama/grpc-server + cp -rfv backend/cpp/llama/grpc-server backend-assets/grpc/llama-cpp +# TODO: every binary should have its own folder instead, so can have different metal implementations +ifeq ($(BUILD_TYPE),metal) + cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/ +endif + backend-assets/grpc/llama-stable: backend-assets/grpc go-llama-stable/libbinding.a $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama-stable CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama-stable LIBRARY_PATH=$(shell pwd)/go-llama \ @@ -451,9 +464,12 @@ backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a backend-assets/grpc/langchain-huggingface: backend-assets/grpc $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/ -backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \ - $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/ +backend-assets/grpc/stablediffusion: backend-assets/grpc + if [ ! -f backend-assets/grpc/stablediffusion ]; then \ + $(MAKE) go-stable-diffusion/libstablediffusion.a; \ + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/; \ + fi backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \ diff --git a/backend/cpp/llama/CMakeLists.txt b/backend/cpp/llama/CMakeLists.txt new file mode 100644 index 00000000..116283fb --- /dev/null +++ b/backend/cpp/llama/CMakeLists.txt @@ -0,0 +1,57 @@ +set(CMAKE_CXX_STANDARD 17) +cmake_minimum_required(VERSION 3.15) +set(TARGET grpc-server) +set(_PROTOBUF_LIBPROTOBUF libprotobuf) +set(_REFLECTION grpc++_reflection) + +find_package(absl CONFIG REQUIRED) +find_package(Protobuf CONFIG REQUIRED) +find_package(gRPC CONFIG REQUIRED) + +find_program(_PROTOBUF_PROTOC protoc) +set(_GRPC_GRPCPP grpc++) +find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${Protobuf_INCLUDE_DIRS}) + +message(STATUS "Using protobuf ${Protobuf_VERSION} ${Protobuf_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}") + + +# Proto file +get_filename_component(hw_proto "../../../../../../pkg/grpc/proto/backend.proto" ABSOLUTE) +get_filename_component(hw_proto_path "${hw_proto}" PATH) + +# Generated sources +set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.cc") +set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.h") +set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.cc") +set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.h") + +add_custom_command( + OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}" + COMMAND ${_PROTOBUF_PROTOC} + ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}" + --cpp_out "${CMAKE_CURRENT_BINARY_DIR}" + -I "${hw_proto_path}" + --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}" + "${hw_proto}" + DEPENDS "${hw_proto}") + +# hw_grpc_proto +add_library(hw_grpc_proto + ${hw_grpc_srcs} + ${hw_grpc_hdrs} + ${hw_proto_srcs} + ${hw_proto_hdrs}) + +add_executable(${TARGET} grpc-server.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto + absl::flags_parse + gRPC::${_REFLECTION} + gRPC::${_GRPC_GRPCPP} + protobuf::${_PROTOBUF_LIBPROTOBUF}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() \ No newline at end of file diff --git a/backend/cpp/llama/Makefile b/backend/cpp/llama/Makefile new file mode 100644 index 00000000..77851eaa --- /dev/null +++ b/backend/cpp/llama/Makefile @@ -0,0 +1,44 @@ + +LLAMA_VERSION?=24ba3d829e31a6eda3fa1723f692608c2fa3adda + +CMAKE_ARGS?= +BUILD_TYPE?= + +# If build type is cublas, then we set -DLLAMA_CUBLAS=ON to CMAKE_ARGS automatically +ifeq ($(BUILD_TYPE),cublas) + CMAKE_ARGS+=-DLLAMA_CUBLAS=ON +# If build type is openblas then we set -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS +# to CMAKE_ARGS automatically +else ifeq ($(BUILD_TYPE),openblas) + CMAKE_ARGS+=-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS +# If build type is clblast (openCL) we set -DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path +else ifeq ($(BUILD_TYPE),clblast) + CMAKE_ARGS+=-DLLAMA_CLBLAST=ON -DCLBlast_DIR=/some/path +# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ +else ifeq ($(BUILD_TYPE),hipblas) + CMAKE_ARGS+=-DLLAMA_HIPBLAS=ON +endif + +llama.cpp: + git clone --recurse-submodules https://github.com/ggerganov/llama.cpp llama.cpp + cd llama.cpp && git checkout -b build $(LLAMA_VERSION) && git submodule update --init --recursive --depth 1 + +llama.cpp/examples/grpc-server: + mkdir -p llama.cpp/examples/grpc-server + cp -r $(abspath ./)/CMakeLists.txt llama.cpp/examples/grpc-server/ + cp -r $(abspath ./)/grpc-server.cpp llama.cpp/examples/grpc-server/ + echo "add_subdirectory(grpc-server)" >> llama.cpp/examples/CMakeLists.txt + +rebuild: + cp -rfv $(abspath ./)/CMakeLists.txt llama.cpp/examples/grpc-server/ + cp -rfv $(abspath ./)/grpc-server.cpp llama.cpp/examples/grpc-server/ + rm -rf grpc-server + $(MAKE) grpc-server + +clean: + rm -rf llama.cpp + rm -rf grpc-server + +grpc-server: llama.cpp llama.cpp/examples/grpc-server + cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release + cp llama.cpp/build/bin/grpc-server . \ No newline at end of file diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp new file mode 100644 index 00000000..16b63aff --- /dev/null +++ b/backend/cpp/llama/grpc-server.cpp @@ -0,0 +1,964 @@ +// llama.cpp gRPC C++ backend server +// +// Ettore Di Giacinto +// +// This is a gRPC server for llama.cpp compatible with the LocalAI proto +// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP, +// but modified to work with gRPC +// + +#include +#include +#include +#include + +#include "common.h" +#include "llama.h" +#include "grammar-parser.h" +#include "backend.pb.h" +#include "backend.grpc.pb.h" + +// include std::regex +#include +#include +#include +#include + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::Status; + + +using backend::HealthMessage; + + +// completion token output with probabilities +struct completion_token_output +{ + struct token_prob + { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; +}; + +static size_t common_part(const std::vector &a, const std::vector &b) +{ + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) + { + } + return i; +} + +enum stop_type +{ + STOP_FULL, + STOP_PARTIAL, +}; + +static bool ends_with(const std::string &str, const std::string &suffix) +{ + return str.size() >= suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +static size_t find_partial_stop_string(const std::string &stop, + const std::string &text) +{ + if (!text.empty() && !stop.empty()) + { + const char text_last_char = text.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) + { + if (stop[char_index] == text_last_char) + { + const std::string current_partial = stop.substr(0, char_index + 1); + if (ends_with(text, current_partial)) + { + return text.size() - char_index - 1; + } + } + } + } + return std::string::npos; +} + +template +static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end) +{ + std::string ret; + for (; begin != end; ++begin) + { + ret += llama_token_to_piece(ctx, *begin); + } + return ret; +} + + +// format incomplete utf-8 multibyte character for output +static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) +{ + std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); + // if the size is 1 and first bit is 1, meaning it's a partial character + // (size > 1 meaning it's already a known token) + if (out.size() == 1 && (out[0] & 0x80) == 0x80) + { + std::stringstream ss; + ss << std::hex << (out[0] & 0xff); + std::string res(ss.str()); + out = "byte: \\x" + res; + } + return out; +} + +struct llama_server_context +{ + bool stream = false; + bool has_next_token = false; + std::string generated_text; + std::vector generated_token_probs; + + size_t num_prompt_tokens = 0; + size_t num_tokens_predicted = 0; + size_t n_past = 0; + size_t n_remain = 0; + + // json prompt; + std::vector embd; + std::vector last_n_tokens; + + llama_model *model = nullptr; + llama_context *ctx = nullptr; + gpt_params params; + int n_ctx; + + grammar_parser::parse_state parsed_grammar; + llama_grammar *grammar = nullptr; + + bool truncated = false; + bool stopped_eos = false; + bool stopped_word = false; + bool stopped_limit = false; + std::string stopping_word; + int32_t multibyte_pending = 0; + + std::mutex mutex; + + std::unique_lock lock() + { + return std::unique_lock(mutex); + } + + ~llama_server_context() + { + if (ctx) + { + llama_free(ctx); + ctx = nullptr; + } + if (model) + { + llama_free_model(model); + model = nullptr; + } + } + + void rewind() + { + params.antiprompt.clear(); + params.grammar.clear(); + num_prompt_tokens = 0; + num_tokens_predicted = 0; + generated_text = ""; + generated_text.reserve(n_ctx); + generated_token_probs.clear(); + truncated = false; + stopped_eos = false; + stopped_word = false; + stopped_limit = false; + stopping_word = ""; + multibyte_pending = 0; + n_remain = 0; + n_past = 0; + + if (grammar != nullptr) { + llama_grammar_free(grammar); + grammar = nullptr; + } + } + + bool loadModel(const gpt_params ¶ms_) + { + printf("load model %s\n", params_.model.c_str()); + + params = params_; + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == nullptr) + { + printf("unable to load model %s\n", params_.model.c_str()); + return false; + } + n_ctx = llama_n_ctx(ctx); + last_n_tokens.resize(n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + return true; + } + + std::vector tokenize_array(const char **prompts, bool add_bos) const + { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + std::vector prompt_tokens; + + + bool first = true; + // Iterate over prompts + for (const char **p = prompts; *p != nullptr; ++p) + { + auto s = std::string(*p); + std::vector pp; + if (first) + { + pp = ::llama_tokenize(ctx, s, add_bos); + first = false; + } + else + { + pp = ::llama_tokenize(ctx, s, false); + } + prompt_tokens.insert(prompt_tokens.end(), pp.begin(), pp.end()); + } + + + return prompt_tokens; + } + + std::vector tokenize_string(const char *prompt, bool add_bos) const + { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + std::vector prompt_tokens; + + auto s = std::string(prompt); + prompt_tokens = ::llama_tokenize(ctx, s, add_bos); + + return prompt_tokens; + } + + bool loadGrammar() + { + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + printf("grammar parse error"); + return false; + } + grammar_parser::print_grammar(stderr, parsed_grammar); + + { + auto it = params.logit_bias.find(llama_token_eos(ctx)); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + printf("EOS token is disabled, which will cause most grammars to fail"); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + return true; + } + + void loadInfill() + { + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + + auto prefix_tokens = tokenize_string(params.input_prefix.c_str(), false); + auto suffix_tokens = tokenize_string(params.input_suffix.c_str(), false); + const int space_token = 29871; + if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(ctx)); + auto prompt_tokens = prefix_tokens; + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) + { + params.n_keep = (int)num_prompt_tokens; + } + params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t)params.n_ctx) + { + printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); + // todo we probably want to cut from both sides + const int n_left = (params.n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + truncated = true; + prompt_tokens = new_tokens; + } + else + { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + embd = prompt_tokens; + if (n_past == num_prompt_tokens) + { + // we have to evaluate at least 1 token to generate logits. + printf("we have to evaluate at least 1 token to generate logits\n"); + n_past--; + } + + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + + has_next_token = true; + } + void loadPrompt(std::string prompt) + { + auto prompt_tokens = tokenize_string(prompt.c_str(), true); // always add BOS + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) + { + params.n_keep = (int)num_prompt_tokens; + } + params.n_keep = std::min(n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t)n_ctx) + { + const int n_left = (n_ctx - params.n_keep) / 2; + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + + truncated = true; + prompt_tokens = new_tokens; + } + else + { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + + + embd = prompt_tokens; + if (n_past == num_prompt_tokens) + { + // we have to evaluate at least 1 token to generate logits. + n_past--; + } + // since #3228 we now have to manually manage the KV cache + + llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + has_next_token = true; + } + + void beginCompletion() + { + // number of tokens to keep when resetting context + n_remain = params.n_predict; + llama_set_rng_seed(ctx, params.seed); + } + + completion_token_output nextToken() + { + completion_token_output result; + result.tok = -1; + + if (embd.size() >= (size_t)n_ctx) + { + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) + { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; + + truncated = true; + + } + + bool tg = true; + while (n_past < embd.size()) + { + int n_eval = (int)embd.size() - n_past; + tg = n_eval == 1; + if (n_eval > params.n_batch) + { + n_eval = params.n_batch; + } + + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) + { + + has_next_token = false; + return result; + } + n_past += n_eval; + } + + if (params.n_predict == 0) + { + has_next_token = false; + result.tok = llama_token_eos(ctx); + return result; + } + + { + // out of user input, sample next token + std::vector candidates; + candidates.reserve(llama_n_vocab(model)); + + result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates); + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + const int32_t n_probs = params.n_probs; + if (params.temp <= 0 && n_probs > 0) + { + // For llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &candidates_p); + } + + for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i) + { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(result.tok); + if (tg) { + num_tokens_predicted++; + } + } + + // add it to the context + embd.push_back(result.tok); + // decrement remaining sampling budget + --n_remain; + + if (!embd.empty() && embd.back() == llama_token_eos(ctx)) + { + // stopping_word = llama_token_to_piece(ctx, embd.back()); + has_next_token = false; + stopped_eos = true; + return result; + } + + has_next_token = params.n_predict == -1 || n_remain != 0; + return result; + } + + size_t findStoppingStrings(const std::string &text, const size_t last_token_size, + const stop_type type) + { + size_t stop_pos = std::string::npos; + for (const std::string &word : params.antiprompt) + { + size_t pos; + if (type == STOP_FULL) + { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + pos = text.find(word, from_pos); + } + else + { + pos = find_partial_stop_string(word, text); + } + if (pos != std::string::npos && + (stop_pos == std::string::npos || pos < stop_pos)) + { + if (type == STOP_FULL) + { + stopping_word = word; + stopped_word = true; + has_next_token = false; + } + stop_pos = pos; + } + } + return stop_pos; + } + + completion_token_output doCompletion() + { + auto token_with_probs = nextToken(); + + const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); + generated_text += token_text; + + if (params.n_probs > 0) + { + generated_token_probs.push_back(token_with_probs); + } + + if (multibyte_pending > 0) + { + multibyte_pending -= token_text.size(); + } + else if (token_text.size() == 1) + { + const char c = token_text[0]; + // 2-byte characters: 110xxxxx 10xxxxxx + if ((c & 0xE0) == 0xC0) + { + multibyte_pending = 1; + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF0) == 0xE0) + { + multibyte_pending = 2; + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + } + else if ((c & 0xF8) == 0xF0) + { + multibyte_pending = 3; + } + else + { + multibyte_pending = 0; + } + } + + if (multibyte_pending > 0 && !has_next_token) + { + has_next_token = true; + n_remain++; + } + + if (!has_next_token && n_remain == 0) + { + stopped_limit = true; + } + + return token_with_probs; + } + + std::vector getEmbedding() + { + static const int n_embd = llama_n_embd(model); + if (!params.embedding) + { + printf("embedding disabled"); + return std::vector(n_embd, 0.0f); + } + const float *data = llama_get_embeddings(ctx); + std::vector embedding(data, data + n_embd); + return embedding; + } +}; + + +static void parse_options_completion(bool streaming,const backend::PredictOptions* predict, llama_server_context &llama) +{ + gpt_params default_params; + + llama.stream = streaming; + llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens(); + llama.params.top_k = predict->topk(); + llama.params.top_p = predict->topp(); + llama.params.tfs_z = predict->tailfreesamplingz(); + llama.params.typical_p = predict->typicalp(); + llama.params.repeat_last_n = predict->repeat(); + llama.params.temp = predict->temperature(); + llama.params.repeat_penalty = predict->penalty(); + llama.params.presence_penalty = predict->presencepenalty(); + llama.params.frequency_penalty = predict->frequencypenalty(); + llama.params.mirostat = predict->mirostat(); + llama.params.mirostat_tau = predict->mirostattau(); + llama.params.mirostat_eta = predict->mirostateta(); + llama.params.penalize_nl = predict->penalizenl(); + llama.params.n_keep = predict->nkeep(); + llama.params.seed = predict->seed(); + llama.params.grammar = predict->grammar(); + // llama.params.n_probs = predict-> + llama.params.prompt = predict->prompt(); + + llama.params.logit_bias.clear(); + + if (predict->ignoreeos()) + { + llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; + } + + // const auto &logit_bias = body.find("logit_bias"); + // if (logit_bias != body.end() && logit_bias->is_array()) + // { + // const int n_vocab = llama_n_vocab(llama.model); + // for (const auto &el : *logit_bias) + // { + // if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) + // { + // llama_token tok = el[0].get(); + // if (tok >= 0 && tok < n_vocab) + // { + // if (el[1].is_number()) + // { + // llama.params.logit_bias[tok] = el[1].get(); + // } + // else if (el[1].is_boolean() && !el[1].get()) + // { + // llama.params.logit_bias[tok] = -INFINITY; + // } + // } + // } + // } + // } + + llama.params.antiprompt.clear(); + for (const std::string& stopPrompt : predict->stopprompts()) { + if (!stopPrompt.empty()) + { + llama.params.antiprompt.push_back(stopPrompt); + } + } +} + + + +static void params_parse(const backend::ModelOptions* request, + gpt_params & params) { + + params.model = request->modelfile(); + // params.model_alias ?? + params.model_alias = request->modelfile(); + params.n_ctx = request->contextsize(); + params.memory_f16 = request->f16memory(); + params.n_threads = request->threads(); + params.n_gpu_layers = request->ngpulayers(); + params.n_batch = request->nbatch(); + + if (!request->tensorsplit().empty()) { + std::string arg_next = request->tensorsplit(); + + // split string by , and / + const std::regex regex{ R"([,/]+)" }; + std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + + GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + + for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) { + if (i_device < split_arg.size()) { + params.tensor_split[i_device] = std::stof(split_arg[i_device]); + } + else { + params.tensor_split[i_device] = 0.0f; + } + } + } + + if (!request->maingpu().empty()) { + params.main_gpu = std::stoi(request->maingpu()); + } + // TODO: lora needs also a scale factor + //params.lora_adapter = request->loraadapter(); + //params.lora_base = request->lorabase(); + params.use_mlock = request->mlock(); + params.use_mmap = request->mmap(); + params.embedding = request->embeddings(); +} + +static bool is_at_eob(llama_server_context &server_context, const llama_token *tokens, const size_t n_tokens) { + return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx); +} + +// Function matching type llama_beam_search_callback_fn_t. +// Custom callback example is called each time the beams lengths increase: +// * Show progress by printing ',' following by number of convergent beam tokens if any. +// * When all beams converge to a common prefix, they are made available in beams_state.beams[0]. +// This is also called when the stop condition is met. +// Collect tokens into std::vector response which is pointed to by callback_data. +static void beam_search_callback(void *callback_data, llama_beams_state beams_state) { + auto & llama = *static_cast(callback_data); + // Mark beams as EOS as needed. + for (size_t i = 0 ; i < beams_state.n_beams ; ++i) { + llama_beam_view& beam_view = beams_state.beam_views[i]; + if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) { + beam_view.eob = true; + } + } + printf(","); // Show progress + if (const size_t n = beams_state.common_prefix_length) { + llama.generated_token_probs.resize(llama.generated_token_probs.size() + n); + assert(0u < beams_state.n_beams); + const llama_token * tokens = beams_state.beam_views[0].tokens; + const auto map = [](llama_token tok) { return completion_token_output{{},tok}; }; + std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map); + printf("%zu", n); + } + fflush(stdout); +#if 0 // DEBUG: print current beams for this iteration + std::cout << "\n\nCurrent beams:\n"; + for (size_t i=0 ; i < beams_state.n_beams ; ++i) { + std::cout << "beams["<set_message("OK"); + return Status::OK; + } + + grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) { + // Implement LoadModel RPC + gpt_params params; + params_parse(request, params); + + llama_backend_init(params.numa); + + // load the model + if (!llama.loadModel(params)) + { + result->set_message("Failed loading model"); + result->set_success(false); + return Status::CANCELLED; + } + result->set_message("Loading succeeded"); + result->set_success(true); + return Status::OK; + } + grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { + // Implement the streaming logic here based on the request options + // You can use writer->Write(response) to send a reply to the client + // and return grpc::Status::OK when the operation is complete. + auto lock = llama.lock(); + + llama.rewind(); + + llama_reset_timings(llama.ctx); + + parse_options_completion(false, request, llama); + + if (!llama.loadGrammar()) + { + //res.status = 400; + return Status::CANCELLED; + } + + llama.loadPrompt(request->prompt()); + llama.beginCompletion(); + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + + while (llama.has_next_token) { + const completion_token_output token_with_probs = llama.doCompletion(); + if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) { + continue; + } + const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); + + size_t pos = std::min(sent_count, llama.generated_text.size()); + + const std::string str_test = llama.generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = + llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); + if (stop_pos != std::string::npos) { + is_stop_full = true; + llama.generated_text.erase( + llama.generated_text.begin() + pos + stop_pos, + llama.generated_text.end()); + pos = std::min(sent_count, llama.generated_text.size()); + } else { + is_stop_full = false; + stop_pos = llama.findStoppingStrings(str_test, token_text.size(), + STOP_PARTIAL); + } + + if ( + stop_pos == std::string::npos || + // Send rest of the text if we are at the end of the generation + (!llama.has_next_token && !is_stop_full && stop_pos > 0) + ) { + const std::string to_send = llama.generated_text.substr(pos, std::string::npos); + + sent_count += to_send.size(); + + std::vector probs_output = {}; + + if (llama.params.n_probs > 0) { + const std::vector to_send_toks = llama_tokenize(llama.ctx, to_send, false); + size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + backend::Reply reply; + reply.set_message(to_send); + + // Send the reply + writer->Write(reply); + } + } + + llama_print_timings(llama.ctx); + + llama.mutex.unlock(); + lock.release(); + return grpc::Status::OK; + } + + + grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) { + auto lock = llama.lock(); + llama.rewind(); + llama_reset_timings(llama.ctx); + parse_options_completion(false, request, llama); + + if (!llama.loadGrammar()) + { + //res.status = 400; + return Status::CANCELLED; + } + + llama.loadPrompt(request->prompt()); + llama.beginCompletion(); + + if (llama.params.n_beams) { + // Fill llama.generated_token_probs vector with final beam. + llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, + llama.n_past, llama.n_remain); + // Translate llama.generated_token_probs to llama.generated_text. + append_to_generated_text_from_generated_token_probs(llama); + } else { + size_t stop_pos = std::string::npos; + + while (llama.has_next_token) { + const completion_token_output token_with_probs = llama.doCompletion(); + const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok); + + stop_pos = llama.findStoppingStrings(llama.generated_text, + token_text.size(), STOP_FULL); + } + + if (stop_pos == std::string::npos) { + stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL); + } + if (stop_pos != std::string::npos) { + llama.generated_text.erase(llama.generated_text.begin() + stop_pos, + llama.generated_text.end()); + } + } + + auto probs = llama.generated_token_probs; + if (llama.params.n_probs > 0 && llama.stopped_word) { + const std::vector stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); + probs = std::vector(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); + } + reply->set_message(llama.generated_text); + return grpc::Status::OK; + } +}; + +void RunServer(const std::string& server_address) { + BackendServiceImpl service; + + ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + server->Wait(); +} + +int main(int argc, char** argv) { + std::string server_address("localhost:50051"); + + // Define long and short options + struct option long_options[] = { + {"addr", required_argument, nullptr, 'a'}, + {nullptr, 0, nullptr, 0} + }; + + // Parse command-line arguments + int option; + int option_index = 0; + while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) { + switch (option) { + case 'a': + server_address = optarg; + break; + default: + std::cerr << "Usage: " << argv[0] << " [--addr=
] or [-a
]" << std::endl; + return 1; + } + } + + RunServer(server_address); + return 0; +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 7773eb1e..5ad9500b 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -17,6 +17,7 @@ import ( const ( LlamaBackend = "llama" LlamaStableBackend = "llama-stable" + LLamaCPP = "llama-cpp" BloomzBackend = "bloomz" StarcoderBackend = "starcoder" GPTJBackend = "gptj" @@ -41,8 +42,9 @@ const ( ) var AutoLoadBackends []string = []string{ - LlamaBackend, + LLamaCPP, LlamaStableBackend, + LlamaBackend, Gpt4All, FalconBackend, GPTNeoXBackend, @@ -175,11 +177,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er } switch backend { - case LlamaBackend, LlamaStableBackend, GPTJBackend, DollyBackend, - MPTBackend, Gpt2Backend, FalconBackend, - GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend, - RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend: - return ml.LoadModel(o.model, ml.grpcModel(backend, o)) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") return ml.LoadModel(o.model, ml.grpcModel(Gpt4All, o)) @@ -187,7 +184,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data") return ml.LoadModel(o.model, ml.grpcModel(PiperBackend, o)) default: - return nil, fmt.Errorf("backend unsupported: %s", o.backendString) + return ml.LoadModel(o.model, ml.grpcModel(backend, o)) } }