mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-04 10:00:37 +00:00
whisper : fix gpu device selection
This commit is contained in:
parent
e940fbf283
commit
c719c5be54
@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
|
|||||||
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
|
||||||
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
|
||||||
|
|
||||||
|
ggml_backend_dev_t dev = nullptr;
|
||||||
|
|
||||||
|
int cnt = 0;
|
||||||
if (params.use_gpu) {
|
if (params.use_gpu) {
|
||||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
|
||||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||||
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
if (cnt == 0 || cnt == params.gpu_device) {
|
||||||
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
|
dev = dev_cur;
|
||||||
if (!result) {
|
}
|
||||||
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
|
||||||
|
if (++cnt > params.gpu_device) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
if (dev == nullptr) {
|
||||||
|
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||||
|
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
|
||||||
|
if (!result) {
|
||||||
|
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
|
||||||
@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
|
|||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
|
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
|
||||||
|
ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
|
||||||
|
|
||||||
if (!params.use_gpu) {
|
if (!params.use_gpu) {
|
||||||
return ggml_backend_cpu_buffer_type();
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we have a GPU device - use it
|
int cnt = 0;
|
||||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||||
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
|
||||||
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
|
if (cnt == 0 || cnt == params.gpu_device) {
|
||||||
return ggml_backend_dev_buffer_type(dev);
|
result = ggml_backend_dev_buffer_type(dev);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (++cnt > params.gpu_device) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ggml_backend_cpu_buffer_type();
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// load the model from a ggml file
|
// load the model from a ggml file
|
||||||
|
Loading…
x
Reference in New Issue
Block a user