whisper : use backend registry (#0)

Georgi Gerganov 2024-11-20 15:32:34 +02:00
parent 9db070a3c5
commit 37c88027e1


@@ -1,43 +1,19 @@
 #include "whisper.h"
 
+#include "ggml-cpu.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
 #ifdef WHISPER_USE_COREML
 #include "coreml/whisper-encoder.h"
 #endif
 
-#include "ggml-cpu.h"
-
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
-
-#ifdef GGML_USE_SYCL
-#include "ggml-sycl.h"
-#endif
-
-#ifdef GGML_USE_VULKAN
-#include "ggml-vulkan.h"
-#endif
-
-#ifdef GGML_USE_BLAS
-#include "ggml-blas.h"
-#endif
-
 #ifdef WHISPER_USE_OPENVINO
 #include "openvino/whisper-openvino-encoder.h"
 #endif
 
-#ifdef GGML_USE_CANN
-#include "ggml-cann.h"
-#endif
-
-#include "ggml.h"
-#include "ggml-alloc.h"
-#include "ggml-backend.h"
-
 #include <atomic>
 #include <algorithm>
 #include <cassert>
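
With the backend registry, the per-backend headers and their GGML_USE_* guards disappear: ggml-backend.h alone is enough to discover whatever backends were compiled into the build. A minimal sketch of that discovery, assuming a ggml build that includes the registry (i.e. anything from this commit onward):

    // list every backend compiled into this ggml build - no vendor
    // headers and no GGML_USE_* guards required
    #include "ggml-backend.h"

    #include <cstdio>

    int main() {
        for (size_t i = 0; i < ggml_backend_reg_count(); ++i) {
            ggml_backend_reg_t reg = ggml_backend_reg_get(i);
            printf("backend %zu: %s (%zu device(s))\n",
                   i, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
        }
        return 0;
    }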
@@ -195,14 +171,13 @@ static bool ggml_graph_compute_helper(
 
     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
-        if (ggml_backend_is_cpu(backend)) {
-            ggml_backend_cpu_set_n_threads(backend, n_threads);
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (fn_set_n_threads) {
+            fn_set_n_threads(backend, n_threads);
         }
-#ifdef GGML_USE_BLAS
-        if (ggml_backend_is_blas(backend)) {
-            ggml_backend_blas_set_n_threads(backend, n_threads);
-        }
-#endif
     }
 
     bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
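
The refactored helper no longer hard-codes ggml_backend_cpu_set_n_threads / ggml_backend_blas_set_n_threads; it asks each backend's registry for the optional "ggml_backend_set_n_threads" entry point and calls it only if present. The same pattern works for any optional backend function. A standalone sketch of the lookup (with an added null guard on reg, which the hunk above relies on never being null for scheduled backends):

    #include "ggml-backend.h"

    // set the thread count on any backend that exposes the optional
    // "ggml_backend_set_n_threads" entry point; backends without the
    // knob (e.g. GPU backends) simply do not export it and are skipped
    static void set_n_threads_if_supported(ggml_backend_t backend, int n_threads) {
        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
        if (!reg) {
            return;
        }
        auto * fn = (ggml_backend_set_n_threads_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (fn) {
            fn(backend, n_threads);
        }
    }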
@@ -1256,67 +1231,23 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
 }
 
 static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
-    ggml_backend_t result = NULL;
-
     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 
-#ifdef GGML_USE_CUDA
     if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
-        result = ggml_backend_cuda_init(params.gpu_device);
-        if (!result) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+                ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
+                if (!result) {
+                    WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                }
+                return result;
+            }
         }
     }
-#endif
 
-#ifdef GGML_USE_METAL
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
-        result = ggml_backend_metal_init();
-        if (!result) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        } else if (!ggml_backend_metal_supports_family(result, 7)) {
-            WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
-            ggml_backend_free(result);
-            result = NULL;
-        }
-    }
-#endif
-
-#ifdef GGML_USE_SYCL
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
-        result = ggml_backend_sycl_init(params.gpu_device);
-        if (!result) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
-        }
-    }
-#endif
-
-#ifdef GGML_USE_VULKAN
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
-        result = ggml_backend_vk_init(params.gpu_device);
-        if (!result) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_vk_init() failed\n", __func__);
-        }
-    }
-#endif
-
-#ifdef GGML_USE_CANN
-    if (params.use_gpu) {
-        WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
-        result = ggml_backend_cann_init(params.gpu_device);
-        if (!result) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_cann_init() failed\n", __func__);
-        }
-    }
-#endif
-
-    GGML_UNUSED(params);
-
-    return result;
+    return nullptr;
 }
 
 static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
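
Instead of one #ifdef block per vendor, the new whisper_backend_init_gpu walks the registry's device list and initializes the first device that reports type GGML_BACKEND_DEVICE_TYPE_GPU, whichever backend provides it. A freestanding sketch of the same selection, using ggml_backend_dev_by_type() as a shorthand for the first-match loop:

    #include "ggml-backend.h"

    // initialize the first GPU device the registry knows about;
    // returns nullptr when the build has no GPU backend at all
    static ggml_backend_t init_first_gpu() {
        ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
        if (!dev) {
            return nullptr;
        }
        // second argument: backend-specific init params, nullptr = defaults
        return ggml_backend_dev_init(dev, nullptr);
    }

One behavioral consequence visible in the diff: the first GPU device wins, so params.gpu_device no longer selects among multiple GPUs in this function.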
@@ -1328,17 +1259,19 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
         result.push_back(backend_gpu);
     }
 
-#ifdef GGML_USE_BLAS
-    {
-        WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
-        ggml_backend_t backend_blas = ggml_backend_blas_init();
-        if (!backend_blas) {
-            WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
-        } else {
-            result.push_back(backend_blas);
+    // ACCEL backends
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+            WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (!backend) {
+                WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                continue;
+            }
+            result.push_back(backend);
         }
     }
-#endif
 
     GGML_UNUSED(params);
 
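
The dedicated BLAS block is subsumed by the same device loop, filtered on GGML_BACKEND_DEVICE_TYPE_ACCEL - the type under which, to the best of my knowledge, the BLAS backend registers its device in registry-era ggml builds. A sketch collecting one initialized backend per ACCEL device:

    #include "ggml-backend.h"

    #include <vector>

    // collect every ACCEL device (e.g. BLAS) as an initialized backend
    static std::vector<ggml_backend_t> init_accel_backends() {
        std::vector<ggml_backend_t> result;
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                if (backend) {
                    result.push_back(backend);
                }
            }
        }
        return result;
    }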
@@ -1348,33 +1281,20 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
 }
 
 static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
-    ggml_backend_buffer_type_t result = nullptr;
-
-    params.use_gpu || (result = ggml_backend_cpu_buffer_type());
+    if (!params.use_gpu) {
+        return ggml_backend_cpu_buffer_type();
+    }
 
-#ifdef GGML_USE_CUDA
-    result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
-#endif
-
-#ifdef GGML_USE_METAL
-    result || (result = ggml_backend_metal_buffer_type());
-#endif
-
-#ifdef GGML_USE_SYCL
-    result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
-#endif
-
-#ifdef GGML_USE_VULKAN
-    result || (result = ggml_backend_vk_buffer_type(params.gpu_device));
-#endif
-
-#ifdef GGML_USE_CANN
-    result || (result == ggml_backend_cann_buffer_type(params.gpu_device));
-#endif
-
-    result || (result = ggml_backend_cpu_buffer_type());
-
-    return result;
+    // if we have a GPU device - use it
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+            WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+            return ggml_backend_dev_buffer_type(dev);
+        }
+    }
+
+    return ggml_backend_cpu_buffer_type();
 }
 
 // load the model from a ggml file
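
whisper_default_buffer_type now follows the same first-GPU rule for weight buffers, replacing the chain of short-circuit "result || (result = ...)" assignments (note the removed CANN line, where "==" instead of "=" made the assignment a no-op). A sketch that picks a buffer type the same way and allocates from it; the 16 MiB size is arbitrary, for illustration only:

    #include "ggml-backend.h"

    #include <cstdio>

    static ggml_backend_buffer_type_t pick_buffer_type(bool use_gpu) {
        if (use_gpu) {
            ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
            if (dev) {
                return ggml_backend_dev_buffer_type(dev);
            }
        }
        return ggml_backend_cpu_buffer_type();
    }

    int main() {
        ggml_backend_buffer_type_t buft = pick_buffer_type(/*use_gpu =*/ true);
        ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16*1024*1024);
        printf("allocated 16 MiB from %s\n", ggml_backend_buft_name(buft));
        ggml_backend_buffer_free(buf);
        return 0;
    }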
@@ -3668,8 +3588,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
     WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
     WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
-
-    // TODO: temporary call to force backend registry initialization
+    WHISPER_LOG_INFO("%s: devices    = %zu\n", __func__, ggml_backend_dev_count());
     WHISPER_LOG_INFO("%s: backends   = %zu\n", __func__, ggml_backend_reg_count());
 
     whisper_context * ctx = new whisper_context;
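
The two counters report different things: ggml_backend_reg_count() is the number of registered backend implementations, while ggml_backend_dev_count() is the total number of devices they expose (one backend can expose several, e.g. a multi-GPU build). Touching the registry this way also initializes it, which is what the removed "temporary call" TODO was working around. A sketch printing the backend-to-device mapping:

    #include "ggml-backend.h"

    #include <cstdio>

    int main() {
        for (size_t i = 0; i < ggml_backend_reg_count(); ++i) {
            ggml_backend_reg_t reg = ggml_backend_reg_get(i);
            for (size_t j = 0; j < ggml_backend_reg_dev_count(reg); ++j) {
                ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, j);
                printf("%s: %s - %s\n", ggml_backend_reg_name(reg),
                       ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
            }
        }
        return 0;
    }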
@@ -7427,6 +7346,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
 static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
     (void) level;
     (void) user_data;
+#ifndef WHISPER_DEBUG
+    if (level == GGML_LOG_LEVEL_DEBUG) {
+        return;
+    }
+#endif
     fputs(text, stderr);
     fflush(stderr);
 }
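
With the default callback now dropping GGML_LOG_LEVEL_DEBUG messages unless WHISPER_DEBUG is defined, applications that want different filtering can install their own callback through the existing whisper_log_set() API. A sketch that keeps debug output at runtime; the "[%d]" level prefix is just for illustration:

    #include "whisper.h"

    #include <cstdio>

    // forward everything to stderr, including debug messages that the
    // default callback would now suppress in non-debug builds
    static void my_log_cb(ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        fprintf(stderr, "[%d] %s", (int) level, text);
    }

    int main() {
        whisper_log_set(my_log_cb, nullptr);
        // ... continue with whisper_init_from_file_with_params(), etc.
        return 0;
    }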