diff --git a/Makefile b/Makefile
index 2c7d0259..2645ddd0 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,7 @@ DETECT_LIBS?=true
 # llama.cpp versions
 GOLLAMA_REPO?=https://github.com/go-skynet/go-llama.cpp
 GOLLAMA_VERSION?=2b57a8ae43e4699d3dc5d1496a1ccd42922993be
-CPPLLAMA_VERSION?=274ec65af6e54039eb95cb44904af5c945dca1fa
+CPPLLAMA_VERSION?=c27ac678dd393af0da9b8acf10266e760c8a0912
 
 # whisper.cpp version
 WHISPER_REPO?=https://github.com/ggerganov/whisper.cpp
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index ea5c4e34..d553d35d 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -2228,6 +2228,42 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
 //     }
 // }
 
+// KV-cache element types that may be requested through the model options.
+const std::vector<ggml_type> kv_cache_types = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_BF16,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_IQ4_NL,
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
+};
+
+// Returns the ggml names of every entry in kv_cache_types, joined by ", ".
+static std::string get_all_kv_cache_types() {
+    std::ostringstream msg;
+    const char * separator = "";
+    for (const auto & type : kv_cache_types) {
+        msg << separator << ggml_type_name(type);
+        separator = ", ";
+    }
+    return msg.str();
+}
+
+// Maps a cache type name (compared against ggml_type_name) to its ggml_type.
+// Throws std::runtime_error when the name matches no entry of kv_cache_types.
+// NOTE(review): params_parse now propagates this exception - confirm the caller handles it.
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    for (const auto & type : kv_cache_types) {
+        if (ggml_type_name(type) == s) {
+            return type;
+        }
+    }
+    throw std::runtime_error("Unsupported cache type: " + s + " (supported: " + get_all_kv_cache_types() + ")");
+}
+
 static void params_parse(const backend::ModelOptions* request,
                                 common_params & params) {
 
@@ -2242,10 +2278,10 @@ static void params_parse(const backend::ModelOptions* request,
     // params.model_alias ??
     params.model_alias = request->modelfile();
     if (!request->cachetypekey().empty()) {
-        params.cache_type_k = request->cachetypekey();
+        params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
     }
     if (!request->cachetypevalue().empty()) {
-        params.cache_type_v = request->cachetypevalue();
+        params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
     }
     params.n_ctx = request->contextsize();
     //params.memory_f16 = request->f16memory();