Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-05-02 08:43:02 +00:00)
cuda : enable CUDA Graph on CUDA Toolkit < 12.x (llama/12394)
* Enable CUDA Graph on CTK < 12.x

  The `cudaGraphExecUpdate` API was changed in CUDA 12.x; because of this, CUDA graph support had been disabled on older CUDA toolkits. This change enables CUDA graph support on CTK < 12.x by calling the older API when building against a toolkit older than 12.x.

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA
parent: db6e8056b5
commit: cfc2560e41
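
As a quick illustration of the incompatibility described above, here is a minimal sketch, not code from this commit, of the two `cudaGraphExecUpdate` signatures that a `CUDART_VERSION` guard has to choose between; the helper name `try_update_graph_exec` and the variable names are made up for the example.

    #include <cuda_runtime.h>

    // Sketch only: pick the cudaGraphExecUpdate() overload that matches the
    // toolkit being compiled against. `instance` is assumed to be a previously
    // instantiated graph executable and `graph` a freshly captured graph.
    static cudaError_t try_update_graph_exec(cudaGraphExec_t instance, cudaGraph_t graph) {
    #if CUDART_VERSION >= 12000
        // CUDA 12.x reports the outcome through a cudaGraphExecUpdateResultInfo struct.
        cudaGraphExecUpdateResultInfo result_info;
        return cudaGraphExecUpdate(instance, graph, &result_info);
    #else
        // Pre-12.x toolkits use the older four-argument form, which returns the
        // offending node and a cudaGraphExecUpdateResult enum instead.
        cudaGraphNode_t error_node;
        cudaGraphExecUpdateResult result;
        return cudaGraphExecUpdate(instance, graph, &error_node, &result);
    #endif
    }

Either overload returns `cudaErrorGraphExecUpdateFailure` when the existing executable cannot be patched to match the new graph.
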
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
 #define USE_CUDA_GRAPH
 #endif
 
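
With the gate relaxed as above, `USE_CUDA_GRAPH` no longer requires `CUDART_VERSION >= 12000`: defining `GGML_CUDA_USE_GRAPHS` (or `GGML_HIP_GRAPHS`) is enough, and the remaining toolkit-version differences are handled at the individual call sites, as in the next hunk.
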
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
 
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
+#if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
-    hipGraphNode_t errorNode;
-    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
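
For context, a hedged sketch of the kind of recovery a caller can perform when the update above is rejected; the function name and the `instance`/`graph` variables are placeholders, not identifiers from this diff. The idea is to drop the stale executable and instantiate a fresh one from the newly captured graph, again guarded on the toolkit version:

    #include <cuda_runtime.h>

    // Illustrative only: rebuild the graph executable after cudaGraphExecUpdate()
    // reports cudaErrorGraphExecUpdateFailure.
    static void reinstantiate_on_update_failure(cudaError_t stat,
                                                cudaGraphExec_t & instance,
                                                cudaGraph_t graph) {
        if (stat != cudaErrorGraphExecUpdateFailure) {
            return;
        }
        cudaGetLastError();               // clear the error left by the failed update
        cudaGraphExecDestroy(instance);   // discard the executable that could not be patched
    #if CUDART_VERSION >= 12000
        cudaGraphInstantiate(&instance, graph, 0);                    // 12.x: flags-only overload
    #else
        cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);  // pre-12.x: legacy signature
    #endif
    }

The rebuilt executable is then launched with `cudaGraphLaunch()` as usual.
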
ggml/src/ggml-cuda/vendors/hip.h (vendored)
@@ -112,7 +112,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate

ggml/src/ggml-cuda/vendors/musa.h (vendored)
@@ -119,7 +119,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
 
 typedef mt_bfloat16 nv_bfloat16;

@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
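
Because the MUSA build no longer defines `GGML_CUDA_USE_GRAPHS`, the `USE_CUDA_GRAPH` gate from the first hunk stays off for MUSA and the graph code path is compiled out entirely, which is the "Disable CUDA Graph for MUSA" part of the commit message.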