diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index af971cab..0167b833 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -53,6 +53,7 @@ struct whisper_params {
     int32_t capture_id = -1;
     int32_t max_tokens = 32;
     int32_t audio_ctx  = 0;
+    int32_t n_gpu_layers = 0;
 
     float vad_thold  = 0.6f;
     float freq_thold = 100.0f;
@@ -90,6 +91,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-c"   || arg == "--capture")      { params.capture_id   = std::stoi(argv[++i]); }
         else if (arg == "-mt"  || arg == "--max-tokens")   { params.max_tokens   = std::stoi(argv[++i]); }
         else if (arg == "-ac"  || arg == "--audio-ctx")    { params.audio_ctx    = std::stoi(argv[++i]); }
+        else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
         else if (arg == "-vth" || arg == "--vad-thold")    { params.vad_thold    = std::stof(argv[++i]); }
         else if (arg == "-fth" || arg == "--freq-thold")   { params.freq_thold   = std::stof(argv[++i]); }
         else if (arg == "-su"  || arg == "--speed-up")     { params.speed_up     = true; }
@@ -134,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -c ID,    --capture ID     [%-7d] capture device ID\n",                        params.capture_id);
     fprintf(stderr, "  -mt N,    --max-tokens N   [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
     fprintf(stderr, "  -ac N,    --audio-ctx N    [%-7d] audio context size (0 - all)\n",             params.audio_ctx);
+    fprintf(stderr, "  -ngl N,   --n-gpu-layers N [%-7d] number of layers to store in VRAM\n",        params.n_gpu_layers);
     fprintf(stderr, "  -vth N,   --vad-thold N    [%-7.2f] voice activity detection threshold\n",     params.vad_thold);
     fprintf(stderr, "  -fth N,   --freq-thold N   [%-7.2f] high-pass frequency cutoff\n",             params.freq_thold);
     fprintf(stderr, "  -su,      --speed-up       [%-7s] speed up audio by x2 (reduced accuracy)\n",  params.speed_up ? "true" : "false");
@@ -268,6 +271,8 @@ int main(int argc, char ** argv) {
     auto lmparams = llama_model_default_params();
     if (!params.use_gpu) {
         lmparams.n_gpu_layers = 0;
+    } else {
+        lmparams.n_gpu_layers = params.n_gpu_layers;
     }
 
     struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lmparams);
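
For context, the patch threads a single integer from the command line down to the llama model parameters: `-ngl N` / `--n-gpu-layers N` is parsed into `whisper_params::n_gpu_layers` and copied into `lmparams.n_gpu_layers` only when GPU use is enabled (otherwise the layer count is forced to 0). Below is a minimal, self-contained sketch of that pattern; the `cli_params` / `model_params` structs and the `-ng/--no-gpu` flag here are simplified illustrations, not the real whisper.cpp or llama.cpp definitions.

```cpp
// Standalone sketch of the pattern the patch introduces: parse an optional
// -ngl / --n-gpu-layers flag and apply it only when GPU use is enabled.
// The structs below are hypothetical stand-ins, not the real project types.
#include <cstdint>
#include <cstdio>
#include <string>

struct cli_params {                    // simplified stand-in for whisper_params
    bool    use_gpu      = true;
    int32_t n_gpu_layers = 0;          // 0 = keep every layer on the CPU
};

struct model_params {                  // simplified stand-in for the llama model params
    int32_t n_gpu_layers = 0;
};

int main(int argc, char ** argv) {
    cli_params params;

    for (int i = 1; i < argc; ++i) {
        const std::string arg = argv[i];
        if ((arg == "-ngl" || arg == "--n-gpu-layers") && i + 1 < argc) {
            params.n_gpu_layers = std::stoi(argv[++i]);
        } else if (arg == "-ng" || arg == "--no-gpu") {  // illustrative flag name
            params.use_gpu = false;
        }
    }

    model_params lmparams;
    if (!params.use_gpu) {
        lmparams.n_gpu_layers = 0;                       // force CPU-only inference
    } else {
        lmparams.n_gpu_layers = params.n_gpu_layers;     // offload the requested layer count
    }

    fprintf(stderr, "n_gpu_layers = %d\n", lmparams.n_gpu_layers);
    return 0;
}
```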