mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
cuda : use CUBLAS_COMPTE_F32 insted of CUBLAS_COMPUTE_F16
This commit is contained in:
parent
f52e74d4dc
commit
c8b3bc6a0d
41
ggml-cuda.cu
41
ggml-cuda.cu
@ -7432,7 +7432,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
||||
}
|
||||
|
||||
__global__ static void k_compute_batched_ptrs(
|
||||
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
|
||||
const half * src0_as_f16, const half * src1_as_f16, float * dst_f32,
|
||||
const void ** ptrs_src, void ** ptrs_dst,
|
||||
int ne12, int ne13,
|
||||
int ne23,
|
||||
@ -7452,7 +7452,7 @@ __global__ static void k_compute_batched_ptrs(
|
||||
|
||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f32 + i12* nb2 + i13* nb3 ;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -7509,9 +7509,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
@ -7519,8 +7516,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
#if 0
|
||||
// use cublasGemmEx
|
||||
@ -7533,10 +7530,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
&alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta, ( char *) dst_ddf + i12* dst->nb[2] + i13* dst->nb[3] , CUDA_R_32F, ne01,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
}
|
||||
@ -7548,11 +7545,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||
&alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||
&beta, ( char *) dst_ddf, CUDA_R_32F, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||
ne12*ne13,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
// use cublasGemmBatchedEx
|
||||
@ -7569,7 +7566,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
src0_as_f16, src1_as_f16, dst_f16,
|
||||
src0_as_f16, src1_as_f16, dst_ddf,
|
||||
ptrs_src, ptrs_dst,
|
||||
ne12, ne13,
|
||||
ne23,
|
||||
@ -7582,11 +7579,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
|
||||
&alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta, ( void **) (ptrs_dst + 0*ne23), CUDA_R_32F, ne01,
|
||||
ne23,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
if (ptrs_src_s != 0) {
|
||||
@ -7598,11 +7595,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
}
|
||||
#endif
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
Loading…
Reference in New Issue
Block a user