whisper : support ggml_conv with CUDA and Metal (#1473)

* ggml : add CUDA support for ggml_conv

* whisper : remove ggml_repeat for conv bias + single backend

* cuda : fix im2col kernel

* metal : add im2col support + mul mat-vec f16 x f16

* bench-all : add q4 models
Author: Georgi Gerganov, 2023-11-10 22:26:50 +02:00 (committed by GitHub)
parent c99e290a7f
commit 933c5bef97
7 changed files with 490 additions and 1111 deletions
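
The gist of the change: ggml_conv_1d/ggml_conv_2d are now lowered onto a single new GGML_OP_IM2COL op followed by an ordinary matrix multiplication, so the convolutions can run on any backend that already has mat-mul kernels (CUDA and Metal). The ggml.c diff is suppressed further down, so the exact implementation is not shown here; the following is only a rough sketch of the 1D decomposition, written against the new ggml_im2col API declared in ggml.h (the function name and comments are illustrative, not the committed code):

#include "ggml.h"

// sketch: 1D convolution lowered to im2col + mat-mul
//   a : kernel [K, IC, OC] (f16)
//   b : input  [IL, IC, N] (f32)
static struct ggml_tensor * conv_1d_via_im2col(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b,
        int s0, int p0, int d0) {
    // unfold the input into patches: [IC*K, OL, N], stored as f16
    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false);

    // one mat-mul computes every output column:
    //   [IC*K, OL*N] x [IC*K, OC] -> [OL*N, OC]
    struct ggml_tensor * result = ggml_mul_mat(ctx,
            ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[1]*im2col->ne[2]),
            ggml_reshape_2d(ctx, a,      a->ne[0]*a->ne[1], a->ne[2]));

    // back to [OL, OC, N]
    return ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]);
}

Since the im2col result is f16 and the whisper conv weights are f16, the conv mat-mul becomes an f16 x f16 product, which is what the new Metal kernel_mul_mv_f16_f16 below is for.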

extra/bench-all.sh

@ -18,11 +18,11 @@ else
fi
models=( \
"tiny" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
"large" "large-q5_0" "large-q5_1" "large-q8_0" \
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
"small" "small-q4_0" "small-q4_1" "small-q5_0" "small-q5_1" "small-q8_0" \
"medium" "medium-q4_0" "medium-q4_1" "medium-q5_0" "medium-q5_1" "medium-q8_0" \
"large" "large-q4_0" "large-q4_1" "large-q5_0" "large-q5_1" "large-q8_0" \
)
if [ "$encoder_only" -eq 0 ]; then

ggml-cuda.cu

@ -4476,6 +4476,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
*dsti = __float2half(*xi);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@ -4729,6 +4736,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int ofs0, int ofs1, int IW, int IH, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
const int offset_dst =
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
} else {
dst[offset_dst] = __float2half(0.0f);
}
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
@ -5618,6 +5644,16 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@ -5701,6 +5737,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
}
static void im2col_f32_f16_cuda(const float * x, half * dst,
int OH, int IW, int IH, int OW, int IC,
int KH, int KW, int N, int ofs0, int ofs1,
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
dim3 block_nums(IC, OH, OW);
dim3 block_dims(N, KH, KW);
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
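
For reference, the launch above maps blockIdx = (ic, oh, ow) and threadIdx = (n, kh, kw), with ofs0/ofs1 being the element strides of the batch and channel dimensions of the f32 source. A plain-C loop nest producing the same destination layout (a sketch for illustration, not code from this commit) would be:

#include "ggml.h" // ggml_fp16_t, ggml_fp32_to_fp16

// CPU reference for the indexing of im2col_f32_f16:
// dst is laid out as [N*OH*OW, IC*KH*KW], one unfolded patch per row
static void im2col_f32_f16_ref(
        const float * x, ggml_fp16_t * dst,
        int N, int IC, int IH, int IW, int KH, int KW, int OH, int OW,
        int ofs0, int ofs1, int s0, int s1, int p0, int p1, int d0, int d1) {
    const int CHW = IC*KH*KW;
    for (int n = 0; n < N; ++n) {
        for (int oh = 0; oh < OH; ++oh) {
            for (int ow = 0; ow < OW; ++ow) {
                for (int ic = 0; ic < IC; ++ic) {
                    for (int kh = 0; kh < KH; ++kh) {
                        for (int kw = 0; kw < KW; ++kw) {
                            const int iiw  = ow*s0 + kw*d0 - p0;
                            const int iih  = oh*s1 + kh*d1 - p1;
                            const int idst = (n*OH*OW + oh*OW + ow)*CHW + ic*KH*KW + kh*KW + kw;
                            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                dst[idst] = ggml_fp32_to_fp16(0.0f); // zero padding
                            } else {
                                dst[idst] = ggml_fp32_to_fp16(x[n*ofs0 + ic*ofs1 + iih*IW + iiw]);
                            }
                        }
                    }
                }
            }
        }
    }
}

The Metal kernel_im2col_f16 further down uses the same mapping, with tgpig/tpitg in place of blockIdx/threadIdx.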
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@ -6483,7 +6528,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
size_t dst_f16_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
@ -6659,6 +6704,45 @@ inline void ggml_cuda_op_alibi(
(void) src1_dd;
}
inline void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
OH, IW, IH, OW, IC, KH, KW, N,
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;
}
inline void ggml_cuda_op_diag_mask_inf(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@ -7549,6 +7633,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@ -7580,6 +7667,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
@ -7943,6 +8034,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_ALIBI:
func = ggml_cuda_alibi;
break;
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
default:
return false;
}

ggml-metal.m

@ -86,6 +86,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@ -114,6 +115,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@ -1139,6 +1145,7 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
nrows = 4;
} break;
@ -1146,13 +1153,18 @@ void ggml_metal_graph_compute(
{
nth0 = 32;
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
nrows = 4;
}
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
nrows = 4;
}
} break;
@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_IM2COL:
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
const int32_t IC = src1->ne[is_2D ? 2 : 1];
const int32_t IH = is_2D ? src1->ne[1] : 1;
const int32_t IW = src1->ne[0];
const int32_t KH = is_2D ? src0->ne[1] : 1;
const int32_t KW = src0->ne[0];
const int32_t OH = is_2D ? dst->ne[2] : 1;
const int32_t OW = dst->ne[1];
const int32_t CHW = IC * KH * KW;
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
default: GGML_ASSERT(false);
};
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:

ggml-metal.metal

@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
}
}
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F16;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
device const half4 * y4 = (device const half4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}
kernel void kernel_mul_mv_f16_f32_1row(
device const char * src0,
device const char * src1,
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_im2col_f16(
device const float * x,
device half * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
const int32_t offset_dst =
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
} else {
dst[offset_dst] = 0.0f;
}
}
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,

ggml.c (1050 changes): diff suppressed because it is too large

ggml.h (19 changes)

@ -403,13 +403,8 @@ extern "C" {
GGML_OP_ROPE_BACK,
GGML_OP_ALIBI,
GGML_OP_CLAMP,
GGML_OP_CONV_1D,
GGML_OP_CONV_1D_STAGE_0, // internal
GGML_OP_CONV_1D_STAGE_1, // internal
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_CONV_2D,
GGML_OP_CONV_2D_STAGE_0, // internal
GGML_OP_CONV_2D_STAGE_1, // internal
GGML_OP_IM2COL,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
@ -1398,6 +1393,18 @@ extern "C" {
float min,
float max);
GGML_API struct ggml_tensor * ggml_im2col(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
bool is_2D);
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,

whisper.cpp

@ -588,7 +588,6 @@ struct whisper_kv_cache {
};
struct whisper_model_data {
ggml_backend_buffer_t buffer_conv; // TODO: tmp until GPU support for conv
ggml_backend_buffer_t buffer_main;
};
@ -818,23 +817,9 @@ struct whisper_context {
whisper_state * state = nullptr;
ggml_backend_t backend_cpu = nullptr;
ggml_backend_t backend_gpu = nullptr;
ggml_backend_t backend = nullptr;
std::string path_model; // populated by whisper_init_from_file_with_params()
ggml_backend_t backend_kv() const {
return backend_gpu ? backend_gpu : backend_cpu;
}
// TODO: always on CPU until we have a GPU support for conv
ggml_backend_t backend_conv() const {
return backend_cpu;
}
ggml_backend_t backend_main() const {
return backend_gpu ? backend_gpu : backend_cpu;
}
};
struct whisper_global {
@ -1192,16 +1177,16 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// encoder
{
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
// map by name
model.tensors["encoder.positional_embedding"] = model.e_pe;
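
The conv bias tensors are now allocated pre-broadcast to the full convolution output length instead of [1, n_audio_state], so the encoder graph can add them with a plain ggml_add and no ggml_repeat; the loader change below fills in the repeated rows. The lengths follow from the usual conv output-size formula applied to the whisper encoder shapes (a small illustrative check, assuming 2*n_audio_ctx = 3000 input mel frames, kernel size 3, padding 1):

// output length of a 1D convolution
static int conv_out_len(int il, int k, int s, int p, int d) {
    return (il + 2*p - d*(k - 1) - 1)/s + 1;
}
// conv1, stride 1: conv_out_len(3000, 3, 1, 1, 1) == 3000 == 2*n_audio_ctx  -> e_conv_1_b
// conv2, stride 2: conv_out_len(3000, 3, 2, 1, 1) == 1500 ==   n_audio_ctx  -> e_conv_2_b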
@ -1394,45 +1379,30 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
#endif
if (backend_gpu) {
wctx.backend_gpu = backend_gpu;
wctx.backend = backend_gpu;
} else {
wctx.backend_gpu = nullptr;
wctx.backend = ggml_backend_cpu_init();
}
// always add the CPU backend as a fallback
wctx.backend_cpu = ggml_backend_cpu_init();
}
{
size_t size_conv = 0;
size_t size_main = 0;
for (const auto & t : model.tensors) {
if (t.first.find("conv") != std::string::npos) {
size_conv += ggml_nbytes(t.second) + ggml_tensor_overhead();
} else {
size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
}
size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
}
model.data->buffer_conv = ggml_backend_alloc_buffer(wctx.backend_conv(), size_conv);
model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend_main(), size_main);
model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend, size_main);
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_conv()), size_conv / 1024.0 / 1024.0);
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_main()), size_main / 1024.0 / 1024.0);
WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1024.0 / 1024.0);
}
ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv);
ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main);
// allocate tensors in the backend buffers
{
for (const auto & t : model.tensors) {
if (t.first.find("conv") != std::string::npos) {
ggml_allocr_alloc(alloc_conv, t.second);
} else {
ggml_allocr_alloc(alloc_main, t.second);
}
ggml_allocr_alloc(alloc_main, t.second);
}
}
@ -1475,45 +1445,67 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias");
if (!is_conv_bias) {
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
const bool is_conv = name.find("conv") != std::string::npos;
ggml_backend * backend = is_conv ? wctx.backend_conv() : wctx.backend_main();
ggml_backend_t backend = wctx.backend;
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
if (ggml_backend_is_cpu(backend)
if ((ggml_backend_is_cpu(backend)
#ifdef GGML_USE_METAL
|| ggml_backend_is_metal(backend)
#endif
) {
) && !is_conv_bias) {
// for the CPU and Metal backend, we can read directly into the tensor
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
BYTESWAP_TENSOR(tensor);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
loader->read(loader->context, read_buf.data(), read_buf.size());
// we repeat the 2 bias tensors along dim 0:
// [1, 512] -> [3000, 512] (conv1.bias)
// [1, 512] -> [1500, 512] (conv2.bias)
if (is_conv_bias) {
loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]);
float * data_f32 = (float *) read_buf.data();
for (int64_t y = 0; y < tensor->ne[1]; ++y) {
const int64_t yy = tensor->ne[1] - y - 1;
const float val = data_f32[yy];
for (int64_t x = 0; x < tensor->ne[0]; ++x) {
data_f32[yy*tensor->ne[0] + x] = val;
}
}
} else {
loader->read(loader->context, read_buf.data(), read_buf.size());
}
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}
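
Note on the expansion loop above: the rows are filled back to front (yy runs from ne[1]-1 down to 0) because the ne[1] source values sit at the start of read_buf; every row with yy > 0 writes at offset yy*ne[0], past the values that are still unread, and row 0 is written last, after its value has already been consumed. A toy-sized version of the same in-place broadcast (hypothetical helper, for illustration only):

#include <stdint.h>

// expand `rows` scalars stored at buf[0..rows-1] into a row-major
// [rows x cols] matrix in the same buffer, filling the last row first
static void expand_rows_in_place(float * buf, int64_t rows, int64_t cols) {
    for (int64_t y = 0; y < rows; ++y) {
        const int64_t yy  = rows - y - 1; // rows-1, ..., 1, 0
        const float   val = buf[yy];
        for (int64_t x = 0; x < cols; ++x) {
            buf[yy*cols + x] = val;       // for yy > 0 this starts at yy*cols >= yy,
        }                                 // so buf[0..yy-1] is untouched until read
    }
}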
@ -1532,7 +1524,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
}
ggml_allocr_free(alloc_conv);
ggml_allocr_free(alloc_main);
wctx.t_load_us = ggml_time_us() - t_start_us;
@ -1595,7 +1586,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
float * dst = wstate.inp_mel.data();
memset(dst, 0, ggml_nbytes(mel));
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
for (int j = 0; j < mel_inp.n_mel; ++j) {
@ -1613,20 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
// convolution + gelu
{
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_1_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_1_b,
// cur),
// cur);
cur = ggml_gelu(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_2_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_2_b,
// cur),
// cur);
cur = ggml_gelu(ctx0, cur);
}
@ -1685,6 +1678,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_allocr * alloc = wstate.alloc_encode.alloc;
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
//ggml_allocr_alloc(alloc, cur);
//if (!ggml_allocr_is_measure(alloc)) {
// ggml_backend_tensor_copy(wstate.embd_conv, cur);
//}
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
struct ggml_tensor * KQscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(alloc, KQscale);
@ -1693,13 +1694,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
}
struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state);
ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
ggml_backend_tensor_copy(wstate.embd_conv, cur);
}
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
@ -1939,12 +1933,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
ggml_allocr * alloc = wstate.alloc_cross.alloc;
struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_allocr_alloc(alloc, cur);
//struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
//ggml_allocr_alloc(alloc, cur);
if (!ggml_allocr_is_measure(alloc)) {
ggml_backend_tensor_copy(wstate.embd_enc, cur);
}
//if (!ggml_allocr_is_measure(alloc)) {
// ggml_backend_tensor_copy(wstate.embd_enc, cur);
//}
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
struct ggml_tensor * Kscale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(alloc, Kscale);
@ -2022,15 +2017,15 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
if (!whisper_encode_external(wstate)) {
if (ggml_backend_is_cpu(wctx.backend_conv())) {
ggml_backend_cpu_set_n_threads(wctx.backend_conv(), n_threads);
if (ggml_backend_is_cpu(wctx.backend)) {
ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(wctx.backend_conv())) {
ggml_backend_metal_set_n_cb(wctx.backend_conv(), n_threads);
if (ggml_backend_is_metal(wctx.backend)) {
ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
}
#endif
ggml_backend_graph_compute(wctx.backend_conv(), gf);
ggml_backend_graph_compute(wctx.backend, gf);
}
}
@ -2044,15 +2039,15 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
if (ggml_backend_is_cpu(wctx.backend_main())) {
ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads);
if (ggml_backend_is_cpu(wctx.backend)) {
ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(wctx.backend_main())) {
ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads);
if (ggml_backend_is_metal(wctx.backend)) {
ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
}
#endif
ggml_backend_graph_compute(wctx.backend_main(), gf);
ggml_backend_graph_compute(wctx.backend, gf);
}
// cross
@ -2065,15 +2060,15 @@ static bool whisper_encode_internal(
ggml_allocr_alloc_graph(alloc, gf);
if (ggml_backend_is_cpu(wctx.backend_main())) {
ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads);
if (ggml_backend_is_cpu(wctx.backend)) {
ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(wctx.backend_main())) {
ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads);
if (ggml_backend_is_metal(wctx.backend)) {
ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
}
#endif
ggml_backend_graph_compute(wctx.backend_main(), gf);
ggml_backend_graph_compute(wctx.backend, gf);
}
wstate.t_encode_us += ggml_time_us() - t_start_us;
@ -2464,15 +2459,15 @@ static bool whisper_decode_internal(
logits = gf->nodes[gf->n_nodes - 1];
if (ggml_backend_is_cpu(wctx.backend_main())) {
ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads);
if (ggml_backend_is_cpu(wctx.backend)) {
ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
}
#ifdef GGML_USE_METAL
if (ggml_backend_is_metal(wctx.backend_main())) {
ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads);
if (ggml_backend_is_metal(wctx.backend)) {
ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
}
#endif
ggml_backend_graph_compute(wctx.backend_main(), gf);
ggml_backend_graph_compute(wctx.backend, gf);
}
// extract logits for all N tokens
@ -2915,7 +2910,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_state * state = new whisper_state;
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend_kv(), ctx->itype, ctx->model.hparams.n_text_ctx)) {
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
@ -2926,7 +2921,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend_kv(), ctx->itype, ctx->model.hparams.n_audio_ctx)) {
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
delete state;
return nullptr;
@ -2968,7 +2963,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// conv allocator
{
whisper_allocr_graph_init(state->alloc_conv, ctx->backend_conv(),
whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
[&]() {
return whisper_build_graph_conv(*ctx, *state, 0);
});
@ -2978,7 +2973,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// encoder allocator
if (!whisper_encode_external(*state)) {
whisper_allocr_graph_init(state->alloc_encode, ctx->backend_main(),
whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
[&]() {
return whisper_build_graph_encoder(*ctx, *state);
});
@ -2988,7 +2983,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// cross allocator
{
whisper_allocr_graph_init(state->alloc_cross, ctx->backend_main(),
whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
[&]() {
return whisper_build_graph_cross(*ctx, *state);
});
@ -2998,7 +2993,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// decoder allocator
{
whisper_allocr_graph_init(state->alloc_decode, ctx->backend_main(),
whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
[&]() {
const auto & hparams = ctx->model.hparams;
@ -3273,7 +3268,6 @@ void whisper_free(struct whisper_context * ctx) {
ggml_free(ctx->model.ctx);
}
if (ctx->model.data) {
ggml_backend_buffer_free(ctx->model.data->buffer_conv);
ggml_backend_buffer_free(ctx->model.data->buffer_main);
delete ctx->model.data;
@ -3281,11 +3275,7 @@ void whisper_free(struct whisper_context * ctx) {
whisper_free_state(ctx->state);
ggml_backend_free(ctx->backend_cpu);
if (ctx->backend_gpu) {
ggml_backend_free(ctx->backend_gpu);
}
ggml_backend_free(ctx->backend);
delete ctx;
}
@ -4583,7 +4573,7 @@ int whisper_full_with_state(
if (decoder.kv_self.ctx == nullptr) {
decoder.kv_self = state->decoders[0].kv_self;
if (!kv_cache_reinit(decoder.kv_self, ctx->backend_kv())) {
if (!kv_cache_reinit(decoder.kv_self, ctx->backend)) {
WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j);
return -4;
}