ggml : fix I8MM Q4_1 scaling factor conversion (llama/10562)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-29 16:25:39 +02:00
parent cb847c20a7
commit 3623bd58f2
2 changed files with 36 additions and 25 deletions

View File

@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), float32_t _scale[4] = {
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
};
float32x4_t scale = vld1q_f32(_scale); float32x4_t scale = vld1q_f32(_scale);
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const block_q8_1 * restrict b_y0 = &vy0[i]; const block_q8_1 * restrict b_y0 = &vy0[i];
const block_q8_1 * restrict b_y1 = &vy1[i]; const block_q8_1 * restrict b_y1 = &vy1[i];
float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), float32_t summs_t[4] = {
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)}; GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
};
summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
const uint8x16_t m4b = vdupq_n_u8(0x0F); const uint8x16_t m4b = vdupq_n_u8(0x0F);
@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
// mmla into int32x4_t // mmla into int32x4_t
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, float32_t _scale[4] = {
GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
};
float32x4_t scale = vld1q_f32(_scale); float32x4_t scale = vld1q_f32(_scale);
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@ -2392,12 +2397,14 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
l1, r1)), l2, r2)), l3, r3))), scale); l1, r1)), l2, r2)), l3, r3))), scale);
} }
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
sumv2 = vaddq_f32(sumv2, summs0); sumv2 = vaddq_f32(sumv2, summs0);
vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s, vget_low_f32 (sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2));
return; return;
} }
#endif #endif
@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), float32_t _scale[4] = {
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
};
float32x4_t scale = vld1q_f32(_scale); float32x4_t scale = vld1q_f32(_scale);
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@ -3395,11 +3404,13 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale); l1, r1)), l2, r2)), l3, r3))), scale);
} }
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
vst1_f32(s, vget_low_f32(sumv2)); vst1_f32(s, vget_low_f32 (sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2));
return; return;
} }
#endif #endif

View File

@ -7641,8 +7641,8 @@ UseGgmlGemm2:;
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
int64_t num_rows_per_vec_dot = vec_dot_num_rows; int64_t num_rows_per_vec_dot = vec_dot_num_rows;
// TODO: currently the mmla kernels support only even numbered rows/cols. // these checks are needed to avoid crossing dim1 boundaries
// this check can be removed once they are extended to support odd numbered rows/cols too // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
num_rows_per_vec_dot = 1; num_rows_per_vec_dot = 1;
} }