opencl: split ggml-opencl.cl into multiple files and cleanup (llama/12886)

---------

Co-authored-by: Shangqing Gu <quic_shawngu@quicinc.com>
Authored by lhez on 2025-04-24 17:46:49 +03:00, committed by Georgi Gerganov
parent fe4acb33e3
commit 88c3cecd43
37 changed files with 5781 additions and 281 deletions


@@ -54,16 +54,41 @@ function(ggml_opencl_add_kernel KNAME)
endfunction()
set(GGML_OPENCL_KERNELS
-ggml-opencl
-ggml-opencl_mm
-ggml-opencl_cvt
-ggml-opencl_gemv_noshuffle
-ggml-opencl_gemv_noshuffle_general
-ggml-opencl_mul_mat_Ab_Bi_8x4
-ggml-opencl_transpose_16
-ggml-opencl_transpose_32
-ggml-opencl_transpose_32_16
-ggml-opencl_im2col
+add
+clamp
+cpy
+cvt
+diag_mask_inf
+gelu
+gemv_noshuffle_general
+gemv_noshuffle
+get_rows
+im2col_f32
+im2col_f16
+mul_mat_Ab_Bi_8x4
+mul_mv_f16_f16
+mul_mv_f16_f32_1row
+mul_mv_f16_f32_l4
+mul_mv_f16_f32
+mul_mv_f32_f32
+mul_mv_q4_0_f32
+mul_mv_q4_0_f32_v
+mul_mv_q4_0_f32_8x_flat
+mul_mv_q4_0_f32_1d_8x_flat
+mul_mv_q4_0_f32_1d_16x_flat
+mul_mv_q6_k
+mul
+norm
+relu
+rms_norm
+rope
+scale
+silu
+softmax_4_f32
+softmax_4_f16
+softmax_f32
+softmax_f16
+transpose
)
foreach (K ${GGML_OPENCL_KERNELS})

File diff suppressed because it is too large.


@@ -0,0 +1,83 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// add
//------------------------------------------------------------------------------
// general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient
kernel void kernel_add(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int i13 = i03 % ne13;
int i12 = i02 % ne12;
int i11 = i01 % ne11;
global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const int i10 = i0 % ne10;
*((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
}
}
// assumption: src1 is a single row
// src1 is broadcast across the rows of src0
kernel void kernel_add_row(
global float4 * src0,
ulong offset0,
global float4 * src1,
ulong offset1,
global float4 * dst,
ulong offsetd,
int ne
) {
src0 = (global float4*)((global char*)src0 + offset0);
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float4*)((global char*)dst + offsetd);
// This performs better than using %.
uint gid = get_global_id(0);
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
dst[gid] = src0[gid] + src1[idx1];
}
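
For context, here is a minimal host-side sketch of how kernel_add_row's argument list maps onto the standard OpenCL C API. It is not taken from ggml's backend code; the queue, kernel, and buffer handles, the helper name, and the choice of global size are placeholder assumptions. The kernel works on float4 elements, so ne is the row length of src1 in float4 units and one work-item is launched per float4 of dst.

// Hypothetical launch of kernel_add_row (identifiers are placeholders).
#include <CL/cl.h>
#include <stddef.h>

static cl_int enqueue_add_row(cl_command_queue queue, cl_kernel k_add_row,
                              cl_mem src0, cl_mem src1, cl_mem dst,
                              cl_ulong off0, cl_ulong off1, cl_ulong offd,
                              cl_int ne /* row length in float4 units */,
                              size_t n_float4 /* float4 elements in dst */) {
    cl_int err;
    if ((err = clSetKernelArg(k_add_row, 0, sizeof(cl_mem),   &src0)) != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 1, sizeof(cl_ulong), &off0)) != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 2, sizeof(cl_mem),   &src1)) != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 3, sizeof(cl_ulong), &off1)) != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 4, sizeof(cl_mem),   &dst))  != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 5, sizeof(cl_ulong), &offd)) != CL_SUCCESS) return err;
    if ((err = clSetKernelArg(k_add_row, 6, sizeof(cl_int),   &ne))   != CL_SUCCESS) return err;
    const size_t global = n_float4; // one work-item per float4 of the output
    return clEnqueueNDRangeKernel(queue, k_add_row, 1, NULL, &global, NULL, 0, NULL, NULL);
}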


@@ -0,0 +1,20 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// clamp
//------------------------------------------------------------------------------
kernel void kernel_clamp(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
float min,
float max
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
min :
(src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
}


@@ -0,0 +1,184 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// cpy
//------------------------------------------------------------------------------
kernel void kernel_cpy_f16_f16(
global half * src0,
ulong offset0,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i3 = n / (ne2*ne1*ne0);
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f16_f32(
global half * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i3 = n / (ne2*ne1*ne0);
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f16(
global float * src0,
ulong offset0,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i3 = n / (ne2*ne1*ne0);
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
kernel void kernel_cpy_f32_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i3 = n / (ne2*ne1*ne0);
int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0];
}
}
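
All four cpy kernels use the same flat-index decomposition to locate the destination element. As a standalone illustration, a C mirror of that arithmetic (the helper below is not part of the kernels):

// Unpack a flat element index n into coordinates (i0,i1,i2,i3), i3 being the
// slowest-varying dimension, exactly as kernel_cpy_* does. For ne0=4, ne1=3,
// ne2=2 and n=30 this yields (i0,i1,i2,i3) = (2,1,0,1): 1*24 + 0*12 + 1*4 + 2 = 30.
static void unpack_index(long n, int ne0, int ne1, int ne2,
                         int *i0, int *i1, int *i2, int *i3) {
    const long s3 = (long) ne2 * ne1 * ne0;  // elements per i3 slice
    const long s2 = (long) ne1 * ne0;        // elements per i2 slice
    *i3 = (int)(n / s3);   n -= (long)(*i3) * s3;
    *i2 = (int)(n / s2);   n -= (long)(*i2) * s2;
    *i1 = (int)(n / ne0);  n -= (long)(*i1) * ne0;
    *i0 = (int) n;
}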


@@ -0,0 +1,118 @@
//------------------------------------------------------------------------------
// This file contains kernels for data conversion.
// These kernels are used when loading the model, so their performance is less
// important.
//------------------------------------------------------------------------------
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_0(
global struct block_q4_0 * src0,
global uchar * dst_q,
global half * dst_d
) {
global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK4_0/2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q4_0(
global uchar * src_q,
global half * src_d,
global struct block_q4_0 * dst
) {
global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK4_0/2; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0_noshuffle
// Flatten q4_0 weights and unshuffle the bits
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_0_noshuffle(
global struct block_q4_0 * src0,
global uchar * dst_q,
global half * dst_d
) {
global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK4_0/4; ++i) {
uchar x0 = b->qs[2*i + 0];
uchar x1 = b->qs[2*i + 1];
q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
#ifdef ADRENO_GPU
// Workaround for Adreno: the kernel must contain the following printf
// statement to work properly; otherwise it produces incorrect results.
// convert_uchar above also seems necessary.
// Compare against a large number so that it does not print anything.
// get_sub_group_local_id() also works.
if (get_global_id(0) == 65536*4096) {
printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
}
#endif
}
}
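
For reference, a host-side sketch of the layout these conversion kernels operate on. The struct and helpers below are illustrative assumptions that mirror the block_q4_0 definition above; they are not part of the backend.

// Hypothetical host-side mirror of block_q4_0, for size bookkeeping only.
// One block covers 32 weights: a 2-byte fp16 scale plus 16 bytes of packed
// 4-bit quants = 18 bytes (4.5 bits per weight).
#include <stdint.h>
#include <stddef.h>

#define QK4_0 32

typedef struct {
    uint16_t d;              // fp16 scale, stored as raw bits on the host
    uint8_t  qs[QK4_0 / 2];  // 32 packed 4-bit quants
} block_q4_0_host;

// After kernel_convert_block_q4_0 runs over n_blocks blocks (AOS -> SOA):
//   dst_q holds n_blocks * QK4_0/2 bytes of quants, contiguous per block,
//   dst_d holds n_blocks fp16 scales.
// kernel_restore_block_q4_0 performs the exact inverse.
static size_t q4_0_soa_q_bytes(size_t n_blocks) { return n_blocks * (QK4_0 / 2); }
static size_t q4_0_soa_d_bytes(size_t n_blocks) { return n_blocks * sizeof(uint16_t); }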


@@ -0,0 +1,58 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// diag_mask_inf kernels
//------------------------------------------------------------------------------
kernel void kernel_diag_mask_inf(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int n_past
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int i02 = get_global_id(2);
int i01 = get_global_id(1);
int i00 = get_global_id(0);
if (i00 > n_past + i01) {
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
}
}
kernel void kernel_diag_mask_inf_8(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd,
int ne00,
int ne01,
int n_past
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
int i = 2*get_global_id(0);
dst[i+0] = src0[i+0];
dst[i+1] = src0[i+1];
int i4 = 4*i;
int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
int i01 = i4/(ne00); i4 -= i01*ne00;
int i00 = i4;
for (int k = 3; k >= 0; --k) {
if (i00 + 4 + k <= n_past + i01) {
break;
}
(&dst[i+1])[k] = -INFINITY;
if (i00 + k > n_past + i01) {
(&dst[i])[k] = -INFINITY;
}
}
}
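
Spelled out, the masking rule both kernels implement is

\text{dst}[i02,\,i01,\,i00] = \begin{cases} -\infty, & i00 > n\_past + i01 \\ \text{src0}[i02,\,i01,\,i00], & \text{otherwise,} \end{cases}

i.e. everything strictly to the right of the causal boundary is masked. For example, with n_past = 2, row i01 = 0 keeps columns 0..2 and masks column 3 onward, while row i01 = 1 keeps columns 0..3. The _8 variant applies the same test to the 8 values it copies: columns i00..i00+3 live in dst[i] and columns i00+4..i00+7 in dst[i+1]; it walks k from 3 down to 0 and stops as soon as the upper half is entirely inside the unmasked region.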


@@ -0,0 +1,62 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// gelu
//------------------------------------------------------------------------------
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
kernel void kernel_gelu(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
kernel void kernel_gelu_quick_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
}
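
For reference, the two approximations implemented above, with the constants spelled out (they correspond to GELU_COEF_A, SQRT_2_OVER_PI and GELU_QUICK_COEF):

\text{gelu}(x) \approx \tfrac{1}{2}\,x\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr),
\qquad
\text{gelu\_quick}(x) = \frac{x}{1 + e^{-1.702\,x}} = x\,\sigma(1.702\,x).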


@@ -0,0 +1,268 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
// assumptions baked into this kernel:
#define QK4_0 32
#define N_SIMDGROUP 4
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
float shared_y; \
shared_y = sub_group_broadcast(y.s0, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 0); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 0); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 0); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 0); \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 0); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 0); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 0); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 1); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 1); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 1); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 1); \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 1); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 1); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 1); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
shared_y = sub_group_broadcast(y.s0, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 2); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 2); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 2); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 2); \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 2); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 2); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 2); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 3); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 3); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 3); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 3); \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 3); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 3); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 3); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
float8 shared_y; \
shared_y = sub_group_broadcast(y, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
shared_y = sub_group_broadcast(y, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle(
__read_only image1d_buffer_t src0_q, // quantized A
global half2 * src0_d, // A scales
__read_only image1d_buffer_t src1, // B
ulong offset1, // offset to B (0)
global float * dst, // C
ulong offsetd, // offset to C (0)
uint K, // K
int ne01, // M
int ne02, // 1
int ne10, // K
int ne12, // 1
int ne0, // M
int ne1, // N
int r2, // 1
int r3)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
__private uint4 regA;
__private half2 regS;
__private float8 regB;
__private float2 totalSum = (float2)(0.0f);
// loop along K in block granularity, skip 4 blocks every iter
for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
// the first 4 fibers in each wave load 8 B values into their private scope
if (slid < 4) {
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
}
// load half weights for two blocks in consecutive rows
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
}
// reduction in local memory, assumes #wave=4
__local float2 reduceLM[SIMDGROUP_WIDTH * 3];
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 2 outputs per fiber in wave 0
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
vstore2(totalSum, 0, &(dst[gid * 2]));
}
}
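
In equation form, each dequantizeBlockAccum_* invocation accumulates, for one q4_0 block and the two adjacent output rows handled by a fiber,

\text{totalSum}.s_r \mathrel{+}= \sum_{j=0}^{31} (q_{r,j} - 8)\,d_r\,y_j, \qquad r \in \{0, 1\},

where q_{r,j} are the 4-bit quants of row r unpacked from the ushort8 weights, d_r is that row's fp16 scale from src0_d, and the activations y_j are read by the first four fibers of the wave and distributed to the rest via sub_group_broadcast. Wave 0 then adds the partial sums of the other three waves from local memory and stores two outputs per fiber.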


@@ -0,0 +1,274 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif
// assumptions baked into this kernel:
#define QK4_0 32
#define N_SIMDGROUP 4
#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
float shared_y; \
shared_y = sub_group_broadcast(y.s0, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 0); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 0); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 0); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 0); \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 0); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 0); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 0); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 1); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 1); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 1); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 1); \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 1); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 1); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 1); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
shared_y = sub_group_broadcast(y.s0, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 2); \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 2); \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 2); \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 2); \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 2); \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 2); \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 2); \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s0, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s1, 3); \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s2, 3); \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s3, 3); \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s4, 3); \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s5, 3); \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s6, 3); \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
shared_y = sub_group_broadcast(y.s7, 3); \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
float8 shared_y; \
shared_y = sub_group_broadcast(y, 0); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 1); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
shared_y = sub_group_broadcast(y, 2); \
total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
shared_y = sub_group_broadcast(y, 3); \
total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle(
__read_only image1d_buffer_t src0_q, // quantized A
global half2 * src0_d, // A scales
__read_only image1d_buffer_t src1, // B
ulong offset1, // offset to B (0)
global float * dst, // C
ulong offsetd, // offset to C (0)
int ne00, // K
int ne01, // M
int ne02, // 1
int ne10, // K
int ne12, // 1
int ne0, // M
int ne1, // N
int r2, // 1
int r3)
{
uint groupId = get_local_id(1);
uint gid = get_global_id(0);
ushort slid = get_sub_group_local_id();
uint K = ne00;
uint M = ne01;
uint LINE_STRIDE_A = M / 2;
uint BLOCK_STRIDE_A = N_SIMDGROUP * M;
__private uint4 regA;
__private half2 regS;
__private float8 regB;
__private float2 totalSum = (float2)(0.0f);
// loop along K in block granularity, skip 4 blocks every iter
for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
// the first 4 fibers in each wave load 8 B values into their private scope
if (slid < 4) {
regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
}
// load half weights for two blocks in consecutive rows
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
#ifdef VECTOR_SUB_GROUP_BROADCAT
dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
#else
dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
#endif // VECTOR_SUB_GROUP_BROADCAT
}
// reduction in local memory, assumes #wave=4
__local float2 reduceLM[SIMDGROUP_WIDTH * 3];
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
barrier(CLK_LOCAL_MEM_FENCE);
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
// 2 outputs per fiber in wave 0
if (groupId == 0) {
dst = (global float*)((global char*)dst + offsetd);
vstore2(totalSum, 0, &(dst[gid * 2]));
}
}


@@ -0,0 +1,163 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
#define QK4_0 32
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// dequantize_q4_0_f32, dequantize_q4_0_f16
//------------------------------------------------------------------------------
void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
global ushort * qs = ((global ushort *)xb + 1);
float d1 = il ? (xb->d / 16.h) : xb->d;
float d2 = d1 / 256.f;
float md = -8.h * xb->d;
ushort mask0 = il ? 0x00F0 : 0x000F;
ushort mask1 = mask0 << 8;
reg->s0 = d1 * (qs[0] & mask0) + md;
reg->s1 = d2 * (qs[0] & mask1) + md;
reg->s2 = d1 * (qs[1] & mask0) + md;
reg->s3 = d2 * (qs[1] & mask1) + md;
reg->s4 = d1 * (qs[2] & mask0) + md;
reg->s5 = d2 * (qs[2] & mask1) + md;
reg->s6 = d1 * (qs[3] & mask0) + md;
reg->s7 = d2 * (qs[3] & mask1) + md;
reg->s8 = d1 * (qs[4] & mask0) + md;
reg->s9 = d2 * (qs[4] & mask1) + md;
reg->sa = d1 * (qs[5] & mask0) + md;
reg->sb = d2 * (qs[5] & mask1) + md;
reg->sc = d1 * (qs[6] & mask0) + md;
reg->sd = d2 * (qs[6] & mask1) + md;
reg->se = d1 * (qs[7] & mask0) + md;
reg->sf = d2 * (qs[7] & mask1) + md;
}
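
The scale juggling above is the usual q4_0 dequantization rearranged so that no explicit bit shifts are needed. For a 4-bit quant q with block scale d,

w = d\,(q - 8) = d\,q + \text{md}, \qquad \text{md} = -8\,d,

and because each quant sits k bits up inside the 16-bit word qs,

d\,\bigl((qs \mathbin{\&} \text{mask}) \gg k\bigr) = \frac{d}{2^{k}}\,(qs \mathbin{\&} \text{mask}),

which is why d1 is d for the low nibbles (or d/16 when il selects the high nibbles) and d2 = d1/256 handles the same nibble one byte higher in qs.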
//------------------------------------------------------------------------------
// get_rows
//------------------------------------------------------------------------------
kernel void kernel_get_rows_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb10,
ulong nb11,
ulong nb1,
ulong nb2
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
kernel void kernel_get_rows_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb10,
ulong nb11,
ulong nb1,
ulong nb2
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
kernel void kernel_get_rows_q4_0(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb10,
ulong nb11,
ulong nb1,
ulong nb2
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
const int NL = 2;
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
float16 temp;
dequantize_q4_0_f32(
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
}
}


@@ -0,0 +1,57 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_im2col_f16(
global float * src1,
ulong offset1,
global half * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global half*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}


@@ -0,0 +1,57 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_im2col_f32(
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
ulong batch_offset,
ulong delta_offset,
long IW,
long IH,
long IC,
long OW,
long OH,
long KW,
long KH,
long pelements,
long CHW,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1
) {
long i = get_global_id(0);
if (i >= pelements) {
return;
}
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
long ksize = OW * (KH > 1 ? KW : 1);
long kx = i / ksize;
long kd = kx * ksize;
long ky = (i - kd) / OW;
long ix = i % OW;
long oh = get_group_id(1);
long batch = get_group_id(2) / IC;
long ic = get_group_id(2) % IC;
long iiw = ix * s0 + kx * d0 - p0;
long iih = oh * s1 + ky * d1 - p1;
long offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
long offset_src = ic * delta_offset + batch * batch_offset;
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
}
}
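
Both im2col kernels compute the same mapping, written out here for reference. Work-item (i, oh, batch*IC + ic) samples the input at

(iih,\ iiw) = (oh \cdot s_1 + ky \cdot d_1 - p_1,\ \ ix \cdot s_0 + kx \cdot d_0 - p_0)

and writes that value, or zero when the sample falls outside the IH x IW window, to

\text{dst}\bigl[\,((batch \cdot OH + oh)\,OW + ix)\,CHW + ic\,KW\,KH + ky\,KW + kx\,\bigr],

with (kx, ky, ix) unpacked from the flat index i exactly as in the code. The usual output-size relation OW = (IW + 2\,p_0 - d_0\,(KW - 1) - 1)/s_0 + 1 is assumed to be established on the host side; it is not computed here.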


@@ -0,0 +1,79 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// mul
//------------------------------------------------------------------------------
kernel void kernel_mul(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
int i13 = i03 % ne13;
int i12 = i02 % ne12;
int i11 = i01 % ne11;
global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const int i10 = i0 % ne10;
*((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
}
}
// assumption: src1 is a single row
// src1 is broadcast across the rows of src0
kernel void kernel_mul_row(
global float4 * src0,
ulong offset0,
global float4 * src1,
ulong offset1,
global float4 * dst,
ulong offsetd,
int ne
) {
src0 = (global float4*)((global char*)src0 + offset0);
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float4*)((global char*)dst + offsetd);
// This performs better than using %.
uint gid = get_global_id(0);
uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
dst[gid] = src0[gid] * src1[idx1];
}


@@ -0,0 +1,139 @@
// src0_q, src0_d, src1 are transposed as a preprocessing step
// 4-bit weights are transposed in groups of 4 (unsigned short int)
// consider weights originally "next to each other", now "on top of each other"
// each fiber computes an 8x4 tile of output elements
// using unshuffled weights
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif
kernel void kernel_mul_mat_Ab_Bi_8x4(
global const ushort * src0_q, // quantized A
global const half * src0_d, // A scales
__read_only image1d_buffer_t src1, // B (1d image)
global float * dst, // C
int m, // M
int n, // N with padding
int k, // K
int n_no_padding // N without padding
) {
int m_4 = m >> 2;
int n_4 = n >> 2;
int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements
half8 B; // registers for activations
half4 dequantized_weights; // registers for dequantized weights
__global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights
__global const half* scale_ptr = src0_d + gx_2; // pointer for scales
for(int i=0; i<k; i+=4){ //loop through K dimension
B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
// keep (i/4) and (i/32) in parentheses so the integer division rounds down first
// load 4 consecutive groups of 4 weights
ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); // (i/4) because weights grouped in 4s
// load 4 consecutive scales
half4 scale = vload4(0, scale_ptr + (i/32)*(m));// (i/32) because 1 scale per 32 elements
// j=0
dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights
dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=1
B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights
dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=2
B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights
dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
// j=3
B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights
dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
}
int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements
// conditionally check that each store targets a valid location; required when N is not a multiple of 8
// the if statements allow registers to be reused for each store,
// reducing the register footprint and thereby increasing the number of concurrent waves
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}

View File

@ -0,0 +1,118 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define N_F16_F16 4
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f16_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3)
{
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int r0 = get_group_id(0);
int rb = get_group_id(1)*N_F16_F16;
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global half * x = (global half *) (src0 + offset_src0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global half * y = (global half *) (src1 + offset_src1);
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
sumf += (half) x[i] * (half) y[i];
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
global half4 * x4 = (global half4 *)x;
for (int row = 0; row < N_F16_F16; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global half * y = (global half *) (src1 + offset_src1);
global half4 * y4 = (global half4 *) y;
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
sumf += (half) x4[i].s0 * y4[i].s0;
sumf += (half) x4[i].s1 * y4[i].s1;
sumf += (half) x4[i].s2 * y4[i].s2;
sumf += (half) x4[i].s3 * y4[i].s3;
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) {
all_sum += (half) x[i] * y[i];
}
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

View File

@ -0,0 +1,118 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define N_F16_F32 4
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f16_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int r0 = get_group_id(0);
int rb = get_group_id(1)*N_F16_F32;
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global half * x = (global half *) (src0 + offset_src0);
if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global float * y = (global float *) (src1 + offset_src1);
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
sumf += convert_float(x[i]) * y[i];
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
global half4 * x4 = (global half4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global float * y = (global float *) (src1 + offset_src1);
global float4 * y4 = (global float4 *) y;
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
sumf += convert_float(x4[i].s0) * y4[i].s0;
sumf += convert_float(x4[i].s1) * y4[i].s1;
sumf += convert_float(x4[i].s2) * y4[i].s2;
sumf += convert_float(x4[i].s3) * y4[i].s3;
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) {
all_sum += (float) x[i] * y[i];
}
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

View File

@ -0,0 +1,94 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f16_f32_1row(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global half * x = (global half *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float sumf = 0;
if (ne00 < 128) {
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} else {
global half4 * x4 = (global half4 *) x;
global float4 * y4 = (global float4 *) y;
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
sumf += (float) x4[i].s0 * y4[i].s0;
sumf += (float) x4[i].s1 * y4[i].s1;
sumf += (float) x4[i].s2 * y4[i].s2;
sumf += (float) x4[i].s3 * y4[i].s3;
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) {
all_sum += (float) x[i] * y[i];
}
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}

View File

@ -0,0 +1,84 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
// Assumes row size (ne00) is a multiple of 4
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f16_f32_l4(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int nrows = ne11;
int r0 = get_group_id(0);
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global half4 * x4 = (global half4 *) (src0 + offset_src0);
for (int r1 = 0; r1 < nrows; ++r1) {
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global float4 * y4 = (global float4 *) (src1 + offset_src1);
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
sumf += convert_float(x4[i].s0) * y4[i].s0;
sumf += convert_float(x4[i].s1) * y4[i].s1;
sumf += convert_float(x4[i].s2) * y4[i].s2;
sumf += convert_float(x4[i].s3) * y4[i].s3;
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}

View File

@ -0,0 +1,118 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define N_F32_F32 4
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f32_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int r0 = get_group_id(0);
int rb = get_group_id(1)*N_F32_F32;
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
global float * x = (global float *) (src0 + offset_src0);
if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global float * y = (global float *) (src1 + offset_src1);
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
global float4 * x4 = (global float4 *)x;
for (int row = 0; row < N_F32_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global float * y = (global float *) (src1 + offset_src1);
global float4 * y4 = (global float4 *) y;
float sumf = 0;
for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
sumf += (float) x4[i].s0 * y4[i].s0;
sumf += (float) x4[i].s1 * y4[i].s1;
sumf += (float) x4[i].s2 * y4[i].s2;
sumf += (float) x4[i].s3 * y4[i].s3;
}
float all_sum = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) {
all_sum += (float) x[i] * y[i];
}
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

View File

@ -0,0 +1,192 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// mul_vec_q_n_f32
//------------------------------------------------------------------------------
// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl values have been multiplied by the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
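// e.g., (qs[i/2] & 0x0F00) equals q*256 for the nibble in bits 8..11; since the matching
// activation was pre-scaled by 1/256, yl*(qs & 0x0F00) recovers y*q, and likewise for the
// 1/16 (0x00F0) and 1/4096 (0xF000) nibbles. The (sumy * -8.f) term below folds in the
// q4_0 offset: sum((q - 8) * y) = sum(q * y) - 8 * sum(y).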
inline float block_q_4_0_dot_y(
global struct block_q4_0 * qb_curr,
float sumy,
private float * yl,
int il
) {
float d = qb_curr->d;
float2 acc = 0.f;
global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
for (int i = 0; i < 8; i+=2) {
acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ yl[i + 9] * (qs[i / 2] & 0xF000);
}
return d * (sumy * -8.f + acc.s0 + acc.s1);
}
#ifdef INTEL_GPU
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
// (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
// id of a SIMD group in the grid.
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache
float sumf[N_DST]={0.f};
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
for (int i = 0; i < 8; i += 2) {
sumy += yb[i] + yb[i+1];
yl[i+0] = yb[i+ 0];
yl[i+1] = yb[i+ 1]/256.f;
sumy += yb[i+16] + yb[i+17];
yl[i+8] = yb[i+16]/16.f;
yl[i+9] = yb[i+17]/4096.f;
}
for (int row = 0; row < N_DST; row++) {
sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
}
// One thread in a SIMD group (i.e., subgroup) handles a half block,
// hence the entire SIMD group handles SIMDWIDTH/2 blocks.
// y points to the activation matrix (of type float). Therefore for
// one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
// SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
// floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
yb += QK4_0 * (N_SIMDWIDTH/2);
}
// The above does not work for Adreno - it produces incorrect results for
// row = 1, 2, 3 and only row = 0 gives the correct result.
// If N_DST is changed, the below array must be initialized accordingly.
// This also seems to perform better on Intel.
float tot[N_DST] = {
sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
for (int row = 0; row < N_DST; ++row) {
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_q4_0_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,307 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
inline float mm_block_q_4_0_dot_y_flat(
global uchar * x,
global half * dh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
global ushort * qs = ((global ushort *)x + il/2);
float acc = 0.f;
acc += yl.s0 * (qs[0] & 0x000F);
acc += yl.s1 * (qs[0] & 0x0F00);
acc += yl.s8 * (qs[0] & 0x00F0);
acc += yl.s9 * (qs[0] & 0xF000);
acc += yl.s2 * (qs[1] & 0x000F);
acc += yl.s3 * (qs[1] & 0x0F00);
acc += yl.sa * (qs[1] & 0x00F0);
acc += yl.sb * (qs[1] & 0xF000);
acc += yl.s4 * (qs[2] & 0x000F);
acc += yl.s5 * (qs[2] & 0x0F00);
acc += yl.sc * (qs[2] & 0x00F0);
acc += yl.sd * (qs[2] & 0xF000);
acc += yl.s6 * (qs[3] & 0x000F);
acc += yl.s7 * (qs[3] & 0x0F00);
acc += yl.se * (qs[3] & 0x00F0);
acc += yl.sf * (qs[3] & 0xF000);
return d * (sumy * -8.f + acc);
}
#ifdef INTEL_GPU
#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix)
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
#elif defined (ADRENO_GPU)
#define N_DST 16
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
//
// This variant performs 1d blocking with 16x output.
// Each simdgroup outputs 16 values on `n0` dim (row in the output matrix).
//
inline void mul_mat_q_n_f32_1d_16x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const int nb = ne00/QK4_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
// Currently with llama2 7B, im is always 0.
// TODO: how to handle im/gqa*(nb*ne0)?
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
// The number of scales is the same as the number of blocks.
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
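// Flat layout: within one 2D slice, block b of row r has its scale at src0_d[r*nb + b] and its
// QK4_0/2 quant bytes starting at src0_q[(r*nb + b)*QK4_0/2]; the offsets above add the
// i12/i13 slice terms on top of first_row*nb.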
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_d;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix*QK4_0 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0.f;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il);
sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il);
sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);
sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);
sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);
sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);
sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);
sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);
yb += QK4_0 * (N_SIMDWIDTH/2);
}
float16 tot = (float16)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),
sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),
sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),
sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),
sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
if (first_row + 4 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
}
if (first_row + 5 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
}
if (first_row + 6 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
}
if (first_row + 7 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
}
if (first_row + 8 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;
}
if (first_row + 9 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;
}
if (first_row + 10 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;
}
if (first_row + 11 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;
}
if (first_row + 12 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc;
}
if (first_row + 13 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;
}
if (first_row + 14 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;
}
if (first_row + 15 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,265 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
inline float mm_block_q_4_0_dot_y_flat(
global uchar * x,
global half * dh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
global ushort * qs = ((global ushort *)x + il/2);
float acc = 0.f;
acc += yl.s0 * (qs[0] & 0x000F);
acc += yl.s1 * (qs[0] & 0x0F00);
acc += yl.s8 * (qs[0] & 0x00F0);
acc += yl.s9 * (qs[0] & 0xF000);
acc += yl.s2 * (qs[1] & 0x000F);
acc += yl.s3 * (qs[1] & 0x0F00);
acc += yl.sa * (qs[1] & 0x00F0);
acc += yl.sb * (qs[1] & 0xF000);
acc += yl.s4 * (qs[2] & 0x000F);
acc += yl.s5 * (qs[2] & 0x0F00);
acc += yl.sc * (qs[2] & 0x00F0);
acc += yl.sd * (qs[2] & 0xF000);
acc += yl.s6 * (qs[3] & 0x000F);
acc += yl.s7 * (qs[3] & 0x0F00);
acc += yl.se * (qs[3] & 0x00F0);
acc += yl.sf * (qs[3] & 0xF000);
return d * (sumy * -8.f + acc);
}
#ifdef INTEL_GPU
#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
#elif defined (ADRENO_GPU)
#define N_DST 8
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
//
// This variant performs 1d blocking with 8x output.
// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).
//
inline void mul_mat_q_n_f32_1d_8x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const int nb = ne00/QK4_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
// Currently with llama2 7B, im is always 0.
// TODO: how to handle im/gqa*(nb*ne0)?
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
// The number of scales is the same as the number of blocks.
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_d;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix*QK4_0 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0.f;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
yb += QK4_0 * (N_SIMDWIDTH/2);
}
float8 tot = (float8)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
if (first_row + 4 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
}
if (first_row + 5 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
}
if (first_row + 6 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
}
if (first_row + 7 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,272 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
// This function requires the original shuffled weights.
// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
// packed together in a byte, so are (q[1], q[17]) and so on.
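// Concretely, byte j of qs (j = 0..15) holds q[j] in its low nibble and q[j+16] in its high
// nibble, which is why the low-nibble masks below pair with yl.s0..s7 (y[0..7]) and the
// high-nibble masks pair with yl.s8..sf (y[16..23]).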
inline float block_q_4_0_dot_y_flat(
global uchar * x,
global half * dh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
global ushort * qs = ((global ushort *)x + il/2);
float acc = 0.f;
acc += yl.s0 * (qs[0] & 0x000F);
acc += yl.s1 * (qs[0] & 0x0F00);
acc += yl.s8 * (qs[0] & 0x00F0);
acc += yl.s9 * (qs[0] & 0xF000);
acc += yl.s2 * (qs[1] & 0x000F);
acc += yl.s3 * (qs[1] & 0x0F00);
acc += yl.sa * (qs[1] & 0x00F0);
acc += yl.sb * (qs[1] & 0xF000);
acc += yl.s4 * (qs[2] & 0x000F);
acc += yl.s5 * (qs[2] & 0x0F00);
acc += yl.sc * (qs[2] & 0x00F0);
acc += yl.sd * (qs[2] & 0xF000);
acc += yl.s6 * (qs[3] & 0x000F);
acc += yl.s7 * (qs[3] & 0x0F00);
acc += yl.se * (qs[3] & 0x00F0);
acc += yl.sf * (qs[3] & 0xF000);
return d * (sumy * -8.f + acc);
}
//
// This variant outputs 8 values.
//
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 8 // each SIMD group works on 8 rows
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
#elif defined (ADRENO_GPU)
#define N_DST 8
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_8x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
// (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
// a SIMD group in the grid. Each SIMD group produces N_DST values in the
// result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
// Currently with llama2 7B, im is always 0.
// TODO: how to handle im/gqa*(nb*ne0)?
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
// The number of scales is the same as the number of blocks.
ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_d;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float8 sumf = 0.f;
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix*QK4_0 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0.f;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
yb += QK4_0 * (N_SIMDWIDTH/2);
}
float8 tot = (float8)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
if (first_row + 4 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
}
if (first_row + 5 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
}
if (first_row + 6 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
}
if (first_row + 7 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_q4_0_f32_8x_flat(
global uchar * src0_q,
global half * src0_d,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,254 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q4_0
//------------------------------------------------------------------------------
struct block_q4_0
{
half d;
uint8_t qs[QK4_0 / 2];
};
//
// This variant unrolls the loops and uses vector types instead of pointers.
// It improves performance on Adreno but not so much on Intel.
//
inline float block_q_4_0_dot_y_v(
global struct block_q4_0 * qb_curr,
float sumy,
float16 yl,
int il
) {
float d = qb_curr->d;
float acc = 0.f;
global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
acc += yl.s0 * (qs[0] & 0x000F);
acc += yl.s1 * (qs[0] & 0x0F00);
acc += yl.s8 * (qs[0] & 0x00F0);
acc += yl.s9 * (qs[0] & 0xF000);
acc += yl.s2 * (qs[1] & 0x000F);
acc += yl.s3 * (qs[1] & 0x0F00);
acc += yl.sa * (qs[1] & 0x00F0);
acc += yl.sb * (qs[1] & 0xF000);
acc += yl.s4 * (qs[2] & 0x000F);
acc += yl.s5 * (qs[2] & 0x0F00);
acc += yl.sc * (qs[2] & 0x00F0);
acc += yl.sd * (qs[2] & 0xF000);
acc += yl.s6 * (qs[3] & 0x000F);
acc += yl.s7 * (qs[3] & 0x0F00);
acc += yl.se * (qs[3] & 0x00F0);
acc += yl.sf * (qs[3] & 0xF000);
return d * (sumy * -8.f + acc);
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_v(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_0;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
// (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
// id of a SIMD group in the grid.
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl; // src1 vector cache
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
// One thread in a SIMD group (i.e., subgroup) handles a half block,
// hence the entire SIMD group handles SIMDWIDTH/2 blocks.
// y points to the activation matrix (of type float). Therefore for
// one thread, the # of blocks y should advance is SIMDWIDTH/2 (because
// SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of
// floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
yb += QK4_0 * (N_SIMDWIDTH/2);
}
// The above does not work for Adreno - it produces incorrect results for
// row = 1, 2, 3 and only row = 0 gives the correct result.
// If N_DST is changed, the below array must be initialized accordingly.
// This also seems to perform better on Intel.
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_q4_0_f32_v(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,190 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_0 32
#define QR4_0 2
#define QK4_1 32
#define QR4_1 2
#define QK5_0 32
#define QR5_0 2
#define QK5_1 32
#define QR5_1 2
#define QK8_0 32
#define QR8_0 1
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
typedef char int8_t;
typedef uchar uint8_t;
typedef short int16_t;
typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
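// (QK_K/2 + QK_K/4 + QK_K/16 quant/scale bytes plus a 2-byte half scale = 210 bytes per
// 256-weight super-block, i.e. 210*8/256 = 6.5625 bits per weight)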
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
half d; // super-block scale
} block_q6_K;
//------------------------------------------------------------------------------
// kernel_mul_mv_q6_K_f32
//------------------------------------------------------------------------------
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 1 // number of rows each SIMD group works on
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 1
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif
#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q6_K_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
uchar kmask1 = 0x03;
uchar kmask2 = 0x0C;
uchar kmask3 = 0x30;
uchar kmask4 = 0xC0;
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int row = N_SIMDGROUP * r0 + get_sub_group_id();
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float sumf = 0;
// For Q6_K quantization, 16 values form a subblock, and 16 subblocks form a
// block. Values in a subblock share a scale that is quantized with 8 bits;
// the entire block shares a single floating point scale.
// For work distribution, each thread processes a subblock (16 weights), hence
// 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
// (super) blocks -- this is the block stride.
// The 16 threads that process a (super) block are split into 2 portions, each has
// 8 threads; each portion works on 8 subblocks.
// For subgroup of 16 threads, the entire subgroup works on a single (super) block
// before moving to the next (super) block. Thread0 - thread7 work on the
// first 8 subblocks; thread8 - thread15 work on the last 8 subblocks.
// Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
// subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
// works on a total of 16 weight values.
int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
int ip = tid/8; // first or second half of (super) block (0 or 1)
int il = tid%8; // each half has 8 parts, one per scale
int n = 4; // 4 scales at a time (and 4 sums)
int l0 = n*il; // offset into half-block, 0..28
int is = 8*ip + l0/16; // 0, 1, 8, 9
int y_offset = 128*ip + l0;
int q_offset_l = 64*ip + l0;
int q_offset_h = 32*ip + l0;
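// Example with a 16-wide subgroup (BLOCK_STRIDE = 1): lane 9 gets tid = 9, ix = 0, ip = 1,
// il = 1, hence l0 = 4, is = 8, y_offset = 132, q_offset_l = 68, q_offset_h = 36, so it
// covers the 16 weights 132..135, 164..167, 196..199 and 228..231 of each super-block.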
for (int i = ix; i < nb; i += BLOCK_STRIDE) {
global uint8_t * q1 = x[i].ql + q_offset_l;
global uint8_t * q2 = q1 + QK_K/8;
global uint8_t * qh = x[i].qh + q_offset_h;
global int8_t * sc = x[i].scales + is;
global float * y = yy + i * QK_K + y_offset;
float dall = x[i].d;
float4 sums = {0.f, 0.f, 0.f, 0.f};
sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
}
float tot = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
dst[r1*ne0 + im*ne0*ne1 + row] = tot;
}
}

View File

@ -0,0 +1,81 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// norm
//------------------------------------------------------------------------------
kernel void kernel_norm(
global void * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
float eps,
local float * sum
) {
src0 = (global void*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
// MEAN
// parallel sum
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
sum[get_local_id(0)] += x[i00];
}
// reduce
barrier(CLK_LOCAL_MEM_FENCE);
for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
float mean = sum[0] / ne00;
// recenter and VARIANCE
barrier(CLK_LOCAL_MEM_FENCE);
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] - mean;
sum[get_local_id(0)] += y[i00] * y[i00];
}
// reduce
barrier(CLK_LOCAL_MEM_FENCE);
for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
float variance = sum[0] / ne00;
float scale = 1.0f/sqrt(variance + eps);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = y[i00] * scale;
}
}

View File

@ -0,0 +1,16 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// relu
//------------------------------------------------------------------------------
kernel void kernel_relu(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
}

View File

@ -0,0 +1,96 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// rms_norm
//------------------------------------------------------------------------------
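// RMS norm over the last dimension: y = x / sqrt(mean(x^2) + eps), with no mean
// subtraction. The row is read as float4 with a scalar tail for the ne00 % 4 leftovers;
// per-subgroup partial sums are written to the local buffer `sum` and combined there.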
// This kernel depends on subgroup size.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_rms_norm(
global void * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
float eps,
local float * sum // Note: the size depends on the number of subgroups
) {
src0 = (global void*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
global float * x_scalar = (global float *) x;
float4 sumf = 0;
float all_sum = 0;
// parallel sum
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
sumf += x[i00] * x[i00];
}
all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
all_sum = sub_group_reduce_add(all_sum);
if (get_sub_group_local_id() == 0) {
sum[get_sub_group_id()] = all_sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
// combine the per-subgroup partial sums (tree reduction in local memory)
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
}
if (get_local_id(0) == 0) {
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
sum[0] += x_scalar[i];
}
sum[0] /= ne00;
}
barrier(CLK_LOCAL_MEM_FENCE);
const float mean = sum[0];
const float scale = 1.0f/sqrt(mean + eps);
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float * y_scalar = (global float *) y;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
y[i00] = x[i00] * scale;
}
if (get_local_id(0) == 0) {
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
y_scalar[i00] = x_scalar[i00] * scale;
}
}
}

View File

@ -0,0 +1,721 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// kernel_rope
//------------------------------------------------------------------------------
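// Rotary position embedding: each selected pair of elements (x0, x1) is rotated by an
// angle that decays geometrically with the dimension index, theta(i0) = pos * freq_base^(-i0/n_dims).
// rope_yarn() optionally blends the interpolated angle (freq_scale * theta) with the
// extrapolated one using the YaRN ramp, and applies the magnitude correction mscale.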
float rope_yarn_ramp(float low, float high, int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
float2 rope_yarn(
float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
return (float2)(cos(theta) * mscale, sin(theta) * mscale);
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}
float2 rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
) {
// start and end correction dims
return (float2)(
max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
);
}
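// The *_norm kernels rotate adjacent element pairs (i0, i0+1); the *_neox kernels rotate
// pairs split across halves of the head, (ic, ic + n_dims/2). Elements with i0 >= n_dims
// are copied through unchanged. src2, when distinct from src0, supplies per-frequency
// factors that divide theta.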
kernel void kernel_rope_norm_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
float theta_base = (float) pos[i2];
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
float theta = theta_base * pow(freq_base, inv_ndims*i0);
float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float x0 = src[0];
float x1 = src[1];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_norm_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
float theta_base = (float) pos[i2];
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
float theta = theta_base * pow(freq_base, inv_ndims*i0);
float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
float x0 = src[0];
float x1 = src[1];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_neox_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
float theta_base = (float) pos[i2];
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_neox_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
float theta_base = (float) pos[i2];
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
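// The *_multi and *_vision kernels implement sectioned rope: `sections` splits the rotated
// dimensions into groups, and each group takes its position from a different row of pos
// (pos[i2 + ne2*k]). This layout is presumably intended for multi-axis / multimodal position
// encodings; the per-pair math is otherwise the same as above.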
kernel void kernel_rope_multi_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_multi_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global half*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
if (i0 < n_dims) {
int ic = i0/2;
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
theta_base = pos[i2];
}
else if (sector >= sections.s0 && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1];
}
else if (sector >= sec_w && sector < sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 2];
}
else if (sector >= sec_w + sections.s2) {
theta_base = pos[i2 + ne2 * 3];
}
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
} else {
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
kernel void kernel_rope_vision_f32(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global float*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}
kernel void kernel_rope_vision_f16(
global void * src0,
ulong offset0,
global int * src1,
ulong offset1,
global float * src2,
ulong offset2,
global half * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne0,
int ne1,
int ne2,
int ne3,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int n_past,
int n_dims,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
int4 sections
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
src2 = (global float*)((global char*)src2 + offset2);
dst = (global half*)((global char*)dst + offsetd);
int i3 = get_group_id(2);
int i2 = get_group_id(1);
int i1 = get_group_id(0);
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
global int * pos = src1;
const int sect_dims = sections.s0 + sections.s1;
const int sec_w = sections.s1 + sections.s0;
float inv_ndims = -1.f/n_dims;
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
int ic = i0/2;
const int sector = (i0/2) % sect_dims;
float theta_base = 0.0f;
if (sector < sections.s0) {
const int p = sector;
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
} else if (sector >= sections.s0 && sector < sec_w) {
const int p = sector - sections.s0;
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
}
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
}
}

View File

@ -0,0 +1,16 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// scale
//------------------------------------------------------------------------------
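// Multiply every element by a constant. Each work-item handles one float4, so the global
// work size is expected to be ne/4 (assumption about how the host enqueues this kernel).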
kernel void kernel_scale(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd,
float scale
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
}

View File

@ -0,0 +1,30 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// silu
//------------------------------------------------------------------------------
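// SiLU (a.k.a. swish): silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
// kernel_silu_4 is the vectorized variant operating on float4.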
kernel void kernel_silu(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = x / (1.0f + exp(-x));
}
kernel void kernel_silu_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = x / (1.0f + exp(-x));
}

View File

@ -0,0 +1,87 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
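// Numerically stable softmax over the last dimension: find the row max, exponentiate
// (x*scale + slope*mask - max), then normalize by the sum. `slope` implements ALiBi when
// max_bias > 0; the mask here is half precision and converted to float. The reduction uses
// subgroup operations, so the local work size is presumably a single subgroup, and this
// float4 variant assumes ne00 is a multiple of 4 (both presumably guaranteed by the host
// when it selects this kernel).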
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4_f16(
global float * src0,
ulong offset0,
global half * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
int h = i02;
float base = h < n_head_log2 ? m0 : m1;
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
}
float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
const float max = sub_group_reduce_max(lmax);
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
}
}

View File

@ -0,0 +1,87 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
global float * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
int h = i02;
float base = h < n_head_log2 ? m0 : m1;
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
const float max = sub_group_reduce_max(lmax);
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
const float sum = sub_group_reduce_add(lsum);
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
pdst4[i00] /= sum;
}
}

View File

@ -0,0 +1,86 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_f16(
global float * src0,
ulong offset0,
global half * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
int h = i02;
float base = h < n_head_log2 ? m0 : m1;
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
float max = sub_group_reduce_max(lmax);
// parallel sum
float lsum = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
}
}

View File

@ -0,0 +1,86 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max(
global float * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
int h = i02;
float base = h < n_head_log2 ? m0 : m1;
int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = -INFINITY;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
float max = sub_group_reduce_max(lmax);
// parallel sum
float lsum = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}
const float sum = sub_group_reduce_add(lsum);
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
pdst[i00] /= sum;
}
}

View File

@ -0,0 +1,84 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
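// Transpose kernels operating on image1d_buffer_t views, where each texel packs 4
// consecutive elements of a row. A work-item loads a 4x4 element tile (one texel from each
// of 4 consecutive rows) and writes it back transposed, so in kernel_transpose_16 and
// kernel_transpose_32 the `rows`/`cols` arguments count texels rather than elements.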
// 16-bit transpose, loading/storing a 4x4 tile of elements
kernel void kernel_transpose_16(
__read_only image1d_buffer_t input,
__write_only image1d_buffer_t output,
const uint rows,
const uint cols
) {
const int i = get_global_id(0);
const int j = get_global_id(1);
const int i_2 = i<<2;
const int j_2 = j<<2;
half4 temp0 = read_imageh(input, (j_2+0)*cols+i);
half4 temp1 = read_imageh(input, (j_2+1)*cols+i);
half4 temp2 = read_imageh(input, (j_2+2)*cols+i);
half4 temp3 = read_imageh(input, (j_2+3)*cols+i);
write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}
// 32-bit transpose, loading/storing a 4x4 tile of elements
kernel void kernel_transpose_32(
__read_only image1d_buffer_t input,
__write_only image1d_buffer_t output,
const uint rows,
const uint cols
) {
const int i = get_global_id(0);
const int j = get_global_id(1);
const int i_2 = i<<2;
const int j_2 = j<<2;
float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}
// 32-bit transpose, loading/storing a 4x4 tile of elements
// Only used for activations
// converts to FP16
// also adds zero padding for prompt lengths that are not a multiple of 8
kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
const int i = get_global_id(0);
const int j = get_global_id(1);
const int i_2 = i<<2;
const int j_2 = j<<2;
half4 temp0 = {0,0,0,0}; // initialize outputs to 0
half4 temp1 = {0,0,0,0};
half4 temp2 = {0,0,0,0};
half4 temp3 = {0,0,0,0};
if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
temp0 = read_imageh(input, (j_2+0)*cols+i);
}
if((j_2+1)*cols+i*4+3 < rows*cols*16){
temp1 = read_imageh(input, (j_2+1)*cols+i);
}
if((j_2+2)*cols+i*4+3 < rows*cols*16){
temp2 = read_imageh(input, (j_2+2)*cols+i);
}
if((j_2+3)*cols+i*4+3 < rows*cols*16){
temp3 = read_imageh(input, (j_2+3)*cols+i);
}
write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}