mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-01 06:50:41 +00:00
185 lines
5.0 KiB
Common Lisp
185 lines
5.0 KiB
Common Lisp
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
|
|
//------------------------------------------------------------------------------
|
|
// cpy
|
|
//------------------------------------------------------------------------------
|
|
|
|
// Copy half-precision elements from src0 to dst (half -> half).
//
// src0 / dst are advanced by offset0 / offsetd BYTES before use.
// ne00..ne03 : source extents, in elements   (ne03 is not read here).
// nb00..nb03 : source strides, in bytes.
// ne0..ne3   : destination extents, in elements (ne3 is not read here).
// nb0..nb3   : destination strides, in bytes.
//
// Work-group ids (0, 1, 2) select the source row (i01, i02, i03); the
// work-items of a group walk that row with a step of the local size.
// The row's flat element index is re-expressed in destination coordinates,
// and the row is then written to consecutive half slots from that position.
kernel void kernel_cpy_f16_f16(
        global half * src0,
        ulong offset0,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global half*)((global char*)src0 + offset0);
    dst  = (global half*)((global char*)dst  + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Flat index of the first element of this source row. Computed in 64-bit:
    // the product of the extents overflows a 32-bit int for tensors with
    // 2^31 or more elements.
    const long n = (long)i03*ne02*ne01*ne00 + (long)i02*ne01*ne00 + (long)i01*ne00;

    // Element counts of one destination plane (dims 0-1) and volume (dims 0-2).
    const long plane  = (long)ne1*ne0;
    const long volume = (long)ne2*plane;

    // Decompose n into destination coordinates.
    const long i3 = n / volume;
    const long i2 = (n - i3*volume) / plane;
    const long i1 = (n - i3*volume - i2*plane) / ne0;
    const long i0 = (n - i3*volume - i2*plane - i1*ne0);

    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        global const half * src = (global const half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}
|
|
|
|
// Copy half-precision elements from src0 into a float destination
// (half -> float; the conversion happens in the element assignment).
//
// src0 / dst are advanced by offset0 / offsetd BYTES before use.
// ne00..ne03 : source extents, in elements   (ne03 is not read here).
// nb00..nb03 : source strides, in bytes.
// ne0..ne3   : destination extents, in elements (ne3 is not read here).
// nb0..nb3   : destination strides, in bytes.
//
// Work-group ids (0, 1, 2) select the source row (i01, i02, i03); the
// work-items of a group walk that row with a step of the local size.
// The row's flat element index is re-expressed in destination coordinates,
// and the row is then written to consecutive float slots from that position.
kernel void kernel_cpy_f16_f32(
        global half * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global half*)((global char*)src0 + offset0);
    dst  = (global float*)((global char*)dst + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Flat index of the first element of this source row. Computed in 64-bit:
    // the product of the extents overflows a 32-bit int for tensors with
    // 2^31 or more elements.
    const long n = (long)i03*ne02*ne01*ne00 + (long)i02*ne01*ne00 + (long)i01*ne00;

    // Element counts of one destination plane (dims 0-1) and volume (dims 0-2).
    const long plane  = (long)ne1*ne0;
    const long volume = (long)ne2*plane;

    // Decompose n into destination coordinates.
    const long i3 = n / volume;
    const long i2 = (n - i3*volume) / plane;
    const long i1 = (n - i3*volume - i2*plane) / ne0;
    const long i0 = (n - i3*volume - i2*plane - i1*ne0);

    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        // const-qualified like the source pointers in the sibling cpy kernels.
        global const half * src = (global const half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}
|
|
|
|
// Copy float elements from src0 into a half-precision destination
// (float -> half; the conversion happens in the element assignment).
//
// src0 / dst are advanced by offset0 / offsetd BYTES before use.
// ne00..ne03 : source extents, in elements   (ne03 is not read here).
// nb00..nb03 : source strides, in bytes.
// ne0..ne3   : destination extents, in elements (ne3 is not read here).
// nb0..nb3   : destination strides, in bytes.
//
// Work-group ids (0, 1, 2) select the source row (i01, i02, i03); the
// work-items of a group walk that row with a step of the local size.
// The row's flat element index is re-expressed in destination coordinates,
// and the row is then written to consecutive half slots from that position.
kernel void kernel_cpy_f32_f16(
        global float * src0,
        ulong offset0,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global float*)((global char*)src0 + offset0);
    dst  = (global half*)((global char*)dst + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Flat index of the first element of this source row. Computed in 64-bit:
    // the product of the extents overflows a 32-bit int for tensors with
    // 2^31 or more elements.
    const long n = (long)i03*ne02*ne01*ne00 + (long)i02*ne01*ne00 + (long)i01*ne00;

    // Element counts of one destination plane (dims 0-1) and volume (dims 0-2).
    const long plane  = (long)ne1*ne0;
    const long volume = (long)ne2*plane;

    // Decompose n into destination coordinates.
    const long i3 = n / volume;
    const long i2 = (n - i3*volume) / plane;
    const long i1 = (n - i3*volume - i2*plane) / ne0;
    const long i0 = (n - i3*volume - i2*plane - i1*ne0);

    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        global const float * src = (global const float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}
|
|
|
|
// Copy float elements from src0 to dst (float -> float).
//
// src0 / dst are advanced by offset0 / offsetd BYTES before use.
// ne00..ne03 : source extents, in elements   (ne03 is not read here).
// nb00..nb03 : source strides, in bytes.
// ne0..ne3   : destination extents, in elements (ne3 is not read here).
// nb0..nb3   : destination strides, in bytes.
//
// Work-group ids (0, 1, 2) select the source row (i01, i02, i03); the
// work-items of a group walk that row with a step of the local size.
// The row's flat element index is re-expressed in destination coordinates,
// and the row is then written to consecutive float slots from that position.
kernel void kernel_cpy_f32_f32(
        global float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global float*)((global char*)src0 + offset0);
    dst  = (global float*)((global char*)dst + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // Flat index of the first element of this source row. Computed in 64-bit:
    // the product of the extents overflows a 32-bit int for tensors with
    // 2^31 or more elements.
    const long n = (long)i03*ne02*ne01*ne00 + (long)i02*ne01*ne00 + (long)i01*ne00;

    // Element counts of one destination plane (dims 0-1) and volume (dims 0-2).
    const long plane  = (long)ne1*ne0;
    const long volume = (long)ne2*plane;

    // Decompose n into destination coordinates.
    const long i3 = n / volume;
    const long i2 = (n - i3*volume) / plane;
    const long i1 = (n - i3*volume - i2*plane) / ne0;
    const long i0 = (n - i3*volume - i2*plane - i1*ne0);

    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        global const float * src = (global const float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}
|