cuda : synchronize graph capture and cublas handle destruction (llama/14288)
Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread.
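For readers unfamiliar with the pattern, the sketch below condenses the synchronization scheme this patch adds into a standalone C++ program: capture threads bump a counter (under a mutex) before starting a graph capture and decrement it when the capture ends, while a context destructor waits on a condition variable until the counter drops to zero before tearing down cuBLAS handles. All names here (gate_mutex, begin_capture, and so on) are illustrative stand-ins rather than ggml identifiers, and the CUDA/cuBLAS calls appear only as comments so the sketch compiles and runs without CUDA.

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <thread>

static std::mutex              gate_mutex;
static std::condition_variable gate_cv;
static std::atomic<int>        capture_count{0};

// called by a compute thread before it starts capturing a graph
void begin_capture() {
    std::lock_guard<std::mutex> lock(gate_mutex);
    capture_count.fetch_add(1, std::memory_order_relaxed);
    // ... cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed) would go here ...
}

// called by the same thread once the capture has ended
void end_capture() {
    // ... cudaStreamEndCapture(stream, &graph) would go here ...
    std::lock_guard<std::mutex> lock(gate_mutex);
    if (capture_count.fetch_sub(1, std::memory_order_relaxed) == 1) {
        gate_cv.notify_all(); // last active capture wakes any waiting destructor
    }
}

// called by a backend-context destructor before destroying cuBLAS handles
void wait_for_captures_then_destroy() {
    std::unique_lock<std::mutex> lock(gate_mutex);
    gate_cv.wait(lock, []{ return capture_count.load(std::memory_order_relaxed) == 0; });
    // ... cublasDestroy(handle) / cudaStreamDestroy(stream) would go here ...
}

int main() {
    std::thread t([]{ begin_capture(); end_capture(); });
    wait_for_captures_then_destroy();
    t.join();
}

Updating the counter only while the mutex is held is what makes the scheme safe: a destructor cannot observe a nonzero counter, release the lock, and then miss the wake-up, because the decrement-and-notify also runs under the same lock.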
commit 33d1f0a3e0
parent 018b2d340e
committed by Georgi Gerganov
ggml/src/ggml-cuda/common.cuh
@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"
 
-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
ggml/src/ggml-cuda/ggml-cuda.cu
@@ -48,6 +48,7 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
@@ -55,9 +56,8 @@
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -515,6 +515,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
@@ -2689,6 +2716,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2764,7 +2796,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
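Two details of the change are worth noting. The destructor of ggml_backend_cuda_context is still declared in common.cuh but is now defined out-of-line in ggml-cuda.cu, since it must reference the file-scope ggml_cuda_lock, ggml_cuda_lock_cv, and ggml_cuda_lock_counter that live in that translation unit. In the end-capture path, fetch_sub(1, ...) returning 1 identifies the thread that finished the last in-flight capture; only that thread calls notify_all(), waking any destructor blocked in ggml_cuda_lock_cv.wait(). Relaxed memory order on the counter is presumably sufficient here because every access to it happens while ggml_cuda_lock is held.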