ggml : add SVE support for q6_K_q8_K (llama/12361)

This commit is contained in:
fj-y-saito 2025-03-18 17:14:39 +09:00 committed by Georgi Gerganov
parent fa72479cfb
commit 9993c3f703

View File

@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#ifdef __ARM_NEON
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);
svuint8_t mone = svdup_n_u8(0x30);
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
for (int i = 0; i < nb; ++i) {
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
const uint8_t * GGML_RESTRICT qh = x[i].qh;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const int8_t * GGML_RESTRICT scale = x[i].scales;
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
const svint64_t prod = svdup_n_s64(0);
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
svdot_s64(prod, q8sums_2, q6scales_2)));
int32_t isum = 0;
switch (vector_length) {
case 128:
{
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
q8bytes_1 = svld1_s8(pg8_16, q8);
q8bytes_2 = svld1_s8(pg8_16, q8+16);
q8bytes_3 = svld1_s8(pg8_16, q8+32);
q8bytes_4 = svld1_s8(pg8_16, q8+48);
q8 += 64;
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
scale += 4;
}
isum += svaddv_s32(pg32_4, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
case 256:
case 512:
{
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; j++) {
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
qh += 32;
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
q6 += 64;
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
q8 += 128;
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
scale += 8;
}
isum += svaddv_s32(pg32_8, isum_tmp);
sum += d_all * y[i].d * (isum - 32 * isum_mins);
}
break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sum;
#elif __ARM_NEON
float sum = 0;
const uint8x16_t m4b = vdupq_n_u8(0xF);