diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 8739baa2..ceb66170 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1647,7 +1647,7 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -1670,7 +1670,7 @@ static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const gg
 
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 
-static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
@@ -2413,32 +2413,304 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node->data;
+    graph_node_properties->node_op = node->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+    }
+}
+
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+        node->op != GGML_OP_CPY &&
+        node->op != GGML_OP_VIEW) {
+        return false;
+    }
+
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_CPY &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+    return true;
+}
+
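The two helpers above give each ggml node a compact fingerprint (data address, op, dimensions, strides, source addresses) so that, on later tokens, a cheap comparison decides whether the captured CUDA graph can simply be re-launched. A standalone sketch of the same idea; the struct here is a simplified stand-in for ggml_graph_node_properties, not the ggml type:

```cuda
// Simplified, self-contained analogue of the fingerprint check above.
// `node_properties` stands in for ggml_graph_node_properties; only the
// fields needed to show the technique are kept.
#include <cstdint>
#include <cstdio>

constexpr int MAX_DIMS = 4;

struct node_properties {
    void *  address;      // where the tensor's data lives
    int64_t ne[MAX_DIMS]; // number of elements per dimension
};

// True when the cached fingerprint still describes the node, i.e. the
// previously captured graph can be reused without an update.
static bool matches(const node_properties & now, const node_properties & cached) {
    if (now.address != cached.address) {
        return false;
    }
    for (int i = 0; i < MAX_DIMS; i++) {
        if (now.ne[i] != cached.ne[i]) {
            return false;
        }
    }
    return true;
}

int main() {
    float buf[64];
    node_properties cached = { buf, { 32, 1, 1, 1 } };
    node_properties same   = { buf, { 32, 1, 1, 1 } };
    node_properties grown  = { buf, { 64, 1, 1, 1 } }; // e.g. context grew

    printf("same  -> update required: %s\n", matches(same,  cached) ? "no" : "yes");
    printf("grown -> update required: %s\n", matches(grown, cached) ? "no" : "yes");
    return 0;
}
```

Note the exemptions in the real check: CPY and VIEW nodes are allowed to change their data addresses between tokens, since those are exactly the pointers that the updated_kernel_arg machinery further down patches back into the graph.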
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
 
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor * node = cgraph->nodes[i];
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
-        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-            continue;
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+    // pointer to CUDA cpy kernel, which is required to identify
+    // kernel parameters which need to be updated in the graph for each token
+    void * ggml_cuda_cpy_fn_ptr = nullptr;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
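CC_AMPERE is not defined in this diff; llama.cpp fills ggml_cuda_info() with a compute-capability value of 100*major + 10*minor, which would put Ampere at 800 (the constant's value below is an assumption based on that convention). A minimal sketch of the same gate queried directly from the runtime:

```cuda
// Sketch: decide CUDA graph eligibility from the device architecture,
// mirroring the cc < CC_AMPERE gate above. CC_AMPERE = 800 is assumed
// from llama.cpp's 100*major + 10*minor convention.
#include <cstdio>
#include <cuda_runtime.h>

#define CC_AMPERE 800

int main() {
    int device = 0;
    cudaGetDevice(&device);

    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);

    const int cc = 100 * prop.major + 10 * prop.minor;
    printf("device %d: cc %d -> CUDA graphs %s\n",
           device, cc, cc < CC_AMPERE ? "disabled (pre-Ampere)" : "eligible");
    return 0;
}
```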
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TODO: investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) {
+            cuda_graph_update_required = true;
         }
 
+        // Check if the graph size has changed
+        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+            cuda_graph_update_required = true;
+            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if (!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            }
+            if (!has_matching_properties) {
+                cuda_graph_update_required = true;
+            }
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+
+        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+
+            if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
+                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-        assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            if (node->src[j] != nullptr) {
-                assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_MUL_MAT_ID) {
+                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+                // disable CUDA graphs for batch size > 1 for now.
+                // Changes in batch size or context size can cause changes to the grid size of some kernels.
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
+
+            if (node->op == GGML_OP_CPY) {
+                // store the copy op parameter which changes with each token.
+                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+                if (ggml_cuda_cpy_fn_ptr == nullptr) {
+                    // store a pointer to the copy op CUDA kernel to identify it later
+                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                }
+            }
+
+            if (!use_cuda_graph) {
+                break;
             }
         }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
+                    }
+                }
 #endif
 
-        bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-        if (!ok) {
-            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
         }
-        GGML_ASSERT(ok);
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+        if (cuda_graph_update_required) {
+            // Extract nodes from graph
+            if (cuda_ctx->cuda_graph->num_nodes == 0) {
+                // First call with null argument gets number of nodes in graph
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+            }
+            // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+            if (cuda_ctx->cuda_graph->num_nodes > 0) {
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+                // Loop over nodes, and extract kernel parameters from each node
+                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                    cudaGraphNodeType node_type;
+                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                    if (node_type == cudaGraphNodeTypeKernel) {
+                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                        if (stat == cudaErrorInvalidDeviceFunction) {
+                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                            // We don't need to update blas nodes, so clear error and move on.
+                            cudaGetLastError();
+                        } else {
+                            GGML_ASSERT(stat == cudaSuccess);
+                        }
+                    }
+                }
+            }
+        }
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+            int k = 0;
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+                }
+            }
+        }
+
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
     }
 
     return GGML_STATUS_SUCCESS;
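Stripped of the backend bookkeeping, the function above follows the standard CUDA graph lifecycle: capture the stream into a cudaGraph_t, instantiate it once into a cudaGraphExec_t, then on later steps try cudaGraphExecUpdate and fall back to re-instantiation when the update is rejected. A self-contained sketch of that lifecycle, with a trivial kernel and reduced error handling (an illustration, not the backend code):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

#define CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
    fprintf(stderr, "line %d: %s\n", __LINE__, cudaGetErrorString(err_)); return 1; } } while (0)

__global__ void scale_kernel(float * x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { x[i] *= s; }
}

int main() {
    const int n = 1024;
    float * d_x;
    CHECK(cudaMalloc(&d_x, n * sizeof(float)));

    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;

    for (int step = 0; step < 3; step++) {
        // Capture this step's work into a graph (relaxed mode, as above).
        CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
        scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_x, 2.0f, n);
        CHECK(cudaStreamEndCapture(stream, &graph));

        if (instance == nullptr) {
            CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
        } else {
            // Try to update the executable graph in place; clear the error
            // and re-instantiate if the topology changed too much.
            cudaGraphExecUpdateResultInfo info;
            if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
                cudaGetLastError();
                CHECK(cudaGraphExecDestroy(instance));
                CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
            }
        }

        CHECK(cudaGraphLaunch(instance, stream));
        CHECK(cudaStreamSynchronize(stream));
        CHECK(cudaGraphDestroy(graph)); // the executable keeps its own copy
    }

    CHECK(cudaGraphExecDestroy(instance));
    CHECK(cudaStreamDestroy(stream));
    CHECK(cudaFree(d_x));
    return 0;
}
```

The backend's extra twist is that it usually skips the capture entirely: when the node fingerprints all match, it goes straight to patching the copy-kernel argument and launching the existing instance.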
diff --git a/ggml-cuda/clamp.cu b/ggml-cuda/clamp.cu
index 379ded04..8009a3e3 100644
--- a/ggml-cuda/clamp.cu
+++ b/ggml-cuda/clamp.cu
@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index b2627b7b..a4197f11 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <memory>
 
 #if defined(GGML_USE_HIPBLAS)
 #include
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
 };
 
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
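The nodes, params and num_nodes members of ggml_cuda_graph exist to support CUDA's two-call query pattern, in which cudaGraphGetNodes is first called with a null array to learn the count. Isolated from the backend, the fill pattern looks like this (extract_kernel_params is a name invented for this sketch):

```cuda
#include <vector>
#include <cuda_runtime.h>

// Fill `nodes` and, for kernel nodes, `params` from a captured graph,
// mirroring how ggml_cuda_graph's vectors are populated above.
static void extract_kernel_params(cudaGraph_t graph,
                                  std::vector<cudaGraphNode_t> & nodes,
                                  std::vector<cudaKernelNodeParams> & params) {
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);      // 1st call: count only

    nodes.resize(num_nodes);
    params.resize(num_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &num_nodes); // 2nd call: handles

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        cudaGraphNodeGetType(nodes[i], &type);
        if (type == cudaGraphNodeTypeKernel) {
            // Introspection can fail for library (e.g. cuBLAS) kernels;
            // clear the error and skip, as the backend does.
            if (cudaGraphKernelNodeGetParams(nodes[i], &params[i]) != cudaSuccess) {
                cudaGetLastError();
            }
        }
    }
}
```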
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 75e50c98..830e2d75 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict_
 }
 
 to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
-    int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
             return dequantize_row_q4_0_cuda;
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda;
         case GGML_TYPE_Q8_0:
-            CUDA_CHECK(cudaGetDevice(&id));
-            if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda;
diff --git a/ggml-cuda/cpy.cu b/ggml-cuda/cpy.cu
index 16d9c8ff..12d741f0 100644
--- a/ggml-cuda/cpy.cu
+++ b/ggml-cuda/cpy.cu
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     ggml_cuda_cpy(ctx, src0, dst);
 }
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ASSERT(false);
+    }
+}
+
diff --git a/ggml-cuda/cpy.cuh b/ggml-cuda/cpy.cuh
index f0b2c453..79616742 100644
--- a/ggml-cuda/cpy.cuh
+++ b/ggml-cuda/cpy.cuh
@@ -5,3 +5,5 @@
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
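ggml_cuda_cpy_fn exists so the backend can recognize copy-kernel nodes inside the captured graph: the host-side address of the kernel instantiation it returns is compared against cudaKernelNodeParams::func. A reduced sketch of that identify-and-patch step (copy_kernel and patch_dst are illustrative stand-ins, not ggml code):

```cuda
#include <cuda_runtime.h>

__global__ void copy_kernel(const char * src, char * dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { dst[i] = src[i]; }
}

// Re-point the 2nd argument (dst) of every copy_kernel node. `new_dst_ptr`
// is the address of a host-side pointer, so the argument value is re-read
// from it when the executable graph is next updated, the same trick as
// updated_kernel_arg in the backend code.
static void patch_dst(cudaGraphNode_t * nodes, cudaKernelNodeParams * params,
                      size_t num_nodes, char ** new_dst_ptr) {
    for (size_t i = 0; i < num_nodes; i++) {
        if (params[i].func == (void *) copy_kernel) {
            params[i].kernelParams[1] = new_dst_ptr; // index 1 == `dst`
            cudaGraphKernelNodeSetParams(nodes[i], &params[i]);
        }
    }
}
```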
diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu
index 60d6616a..7948f1b1 100644
--- a/ggml-cuda/mmq.cu
+++ b/ggml-cuda/mmq.cu
@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
 
 #if QK_K == 256
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
     int mmq_x, mmq_y, nwarps;
diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu
index 39655900..65cc1bca 100644
--- a/ggml-cuda/mmvq.cu
+++ b/ggml-cuda/mmvq.cu
@@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
 
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
@@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(
 
     const int64_t ne0 = dst->ne[0];
 
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
+    int id = ggml_cuda_get_device();
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
diff --git a/ggml-cuda/scale.cu b/ggml-cuda/scale.cu
index 6e3617d1..1405e066 100644
--- a/ggml-cuda/scale.cu
+++ b/ggml-cuda/scale.cu
@@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale, dst->op_params, sizeof(float));
 
     scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
-    CUDA_CHECK(cudaGetLastError());
 }
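The mmq.cu, mmvq.cu and convert.cu hunks all replace the two-line cudaGetDevice boilerplate with ggml_cuda_get_device(), whose definition is not part of this diff; presumably it is a small helper along these lines (sketch inferred from usage):

```cuda
// Assumed shape of the helper the hunks above call; not shown in this diff.
static int ggml_cuda_get_device() {
    int id;
    CUDA_CHECK(cudaGetDevice(&id));
    return id;
}
```

Also worth noting from the first file: even in builds where USE_CUDA_GRAPH is defined, setting the GGML_CUDA_DISABLE_GRAPHS environment variable falls back to the plain per-node execution path at runtime.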