mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-03-01 03:26:12 +00:00
ggml: aarch64: implement SVE kernels for q3_K_q8_K vector dot (llama/11917)
* Added SVE Implementation for Q3_K Kernel in ggml-cpu-quants.c file * Improved Formating of code in ggml-cpu-quants.c file * style : minor fixes * style : less whitespaces * style : ptr spaceing --------- Co-authored-by: vithulep <p.m.vithule1517@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
51a3580c79
commit
64a430bc81
@ -5112,7 +5112,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
|
|||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef __ARM_NEON
|
#if defined(__ARM_FEATURE_SVE)
|
||||||
|
|
||||||
|
uint32_t utmp[4];
|
||||||
|
|
||||||
|
const int8_t m32 = 32;
|
||||||
|
const int vector_length = svcntb()*8;
|
||||||
|
const svuint8_t m3b_sv = svdup_n_u8(0x3);
|
||||||
|
const svint32_t vzero_sv = svdup_n_s32(0);
|
||||||
|
|
||||||
|
const svuint8_t m0_sv = svdup_n_u8(1);
|
||||||
|
const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
|
||||||
|
const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
|
||||||
|
const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
|
||||||
|
svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
|
||||||
|
|
||||||
|
float sum = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
|
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
||||||
|
|
||||||
|
const uint8_t * restrict q3_sv = x[i].qs;
|
||||||
|
const uint8_t * restrict qh_sv = x[i].hmask;
|
||||||
|
const int8_t * restrict q8_sv = y[i].qs;
|
||||||
|
|
||||||
|
// Set up scales
|
||||||
|
uint32_t * aux = &x[i].scales;
|
||||||
|
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
|
||||||
|
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
|
||||||
|
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
|
||||||
|
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
|
||||||
|
|
||||||
|
int8_t * scale = (int8_t *)utmp;
|
||||||
|
|
||||||
|
for (int j = 0; j < 16; ++j) scale[j] -= m32;
|
||||||
|
|
||||||
|
switch (vector_length) {
|
||||||
|
case 128:
|
||||||
|
{
|
||||||
|
svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
|
||||||
|
svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
|
||||||
|
svuint8_t q3h_sv;
|
||||||
|
|
||||||
|
svint32_t sumi1_1 = svdup_n_s32(0);
|
||||||
|
svint8_t q3bytes_sv;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
|
const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
|
||||||
|
const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
|
||||||
|
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
|
||||||
|
|
||||||
|
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
|
||||||
|
|
||||||
|
|
||||||
|
scale += 4;
|
||||||
|
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
|
||||||
|
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
|
||||||
|
|
||||||
|
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
|
||||||
|
|
||||||
|
|
||||||
|
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
||||||
|
|
||||||
|
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
|
||||||
|
|
||||||
|
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
|
||||||
|
|
||||||
|
if (j == 0) {
|
||||||
|
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
|
||||||
|
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
scale += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
|
||||||
|
} break;
|
||||||
|
case 256:
|
||||||
|
case 512:
|
||||||
|
{
|
||||||
|
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
|
||||||
|
svuint8_t q3h_sv;
|
||||||
|
|
||||||
|
svint32_t sumi1_1 = svdup_n_s32(0);
|
||||||
|
svint8_t q3bytes_sv;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K/128; ++j) {
|
||||||
|
|
||||||
|
const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
|
||||||
|
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||||
|
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
|
||||||
|
svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
||||||
|
|
||||||
|
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
||||||
|
|
||||||
|
scale += 4;
|
||||||
|
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||||
|
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
||||||
|
|
||||||
|
q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
|
||||||
|
|
||||||
|
q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
|
||||||
|
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
|
||||||
|
|
||||||
|
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
|
||||||
|
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
|
||||||
|
|
||||||
|
if (j == 0) {
|
||||||
|
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
scale += 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported vector length");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*s = sum;
|
||||||
|
|
||||||
|
#elif __ARM_NEON
|
||||||
|
|
||||||
uint32_t aux[3];
|
uint32_t aux[3];
|
||||||
uint32_t utmp[4];
|
uint32_t utmp[4];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user