whisper : fix gpu device selection (#2728)
Some checks are pending
Bindings Tests (Ruby) / ubuntu-latest (push) Waiting to run
CI / ubuntu-latest (linux/amd64) (push) Waiting to run
CI / ubuntu-latest (linux/ppc64le) (push) Waiting to run
CI / ubuntu-latest-arm64 (linux/arm64) (push) Waiting to run
CI / ubuntu-latest-arm-v7 (linux/arm/v7) (push) Waiting to run
CI / macOS-latest (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-arm64 (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc-arm64 (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-arm-v7 (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc-arm-v7 (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Waiting to run
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Waiting to run

This commit is contained in:
Georgi Gerganov 2025-01-13 13:11:37 +02:00 committed by GitHub
parent e940fbf283
commit eb68324c86
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) { static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
ggml_backend_dev_t dev = nullptr;
int cnt = 0;
if (params.use_gpu) { if (params.use_gpu) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i); ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
if (cnt == 0 || cnt == params.gpu_device) {
dev = dev_cur;
}
if (++cnt > params.gpu_device) {
break;
}
}
}
}
if (dev == nullptr) {
WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
return nullptr;
}
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
if (!result) { if (!result) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
} }
return result;
}
}
}
return nullptr; return result;
} }
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) { static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
} }
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) { static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
if (!params.use_gpu) { if (!params.use_gpu) {
return ggml_backend_cpu_buffer_type(); return result;
} }
// if we have a GPU device - use it int cnt = 0;
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i); ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); if (cnt == 0 || cnt == params.gpu_device) {
return ggml_backend_dev_buffer_type(dev); result = ggml_backend_dev_buffer_type(dev);
}
if (++cnt > params.gpu_device) {
break;
}
} }
} }
return ggml_backend_cpu_buffer_type(); return result;
} }
// load the model from a ggml file // load the model from a ggml file