Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2024-12-19 04:37:51 +00:00)
ggml : fix I8MM Q4_1 scaling factor conversion (llama/10562)
ggml-ci
commit 3623bd58f2
parent cb847c20a7
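The functional change is in ggml_vec_dot_q4_1_q8_1: in the I8MM (vmmlaq_s32) path, the per-block scale products were built as GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, i.e. the q8_1 scale was used without an explicit FP16 to FP32 conversion. The hunks below add the missing GGML_FP16_TO_FP32 calls; the q4_0/q8_0 hunks and the comment hunk at the end are formatting and wording cleanups only. The following standalone sketch (not ggml code) shows why the conversion matters, assuming the block scales are stored as raw IEEE binary16 bit patterns the way ggml_half is declared in ggml-common.h; fp16_to_fp32 is a hypothetical stand-in for GGML_FP16_TO_FP32.

// Minimal sketch: multiplying by an unconverted 16-bit scale uses its integer
// bit pattern, not its half-precision value.
#include <stdint.h>
#include <stdio.h>
#include <math.h>

static float fp16_to_fp32(uint16_t h) {
    // decode IEEE 754 binary16 -> binary32 (no NaN/Inf handling, illustration only)
    uint32_t sign = (uint32_t)(h >> 15) & 1;
    uint32_t exp  = (uint32_t)(h >> 10) & 0x1F;
    uint32_t mant = (uint32_t)(h & 0x3FF);
    float val;
    if (exp == 0) {
        val = ldexpf((float) mant, -24);                      // subnormal
    } else {
        val = ldexpf((float)(mant | 0x400), (int) exp - 25);  // normal
    }
    return sign ? -val : val;
}

int main(void) {
    uint16_t d_x = 0x3C00; // 1.0f as binary16
    uint16_t d_y = 0x2E66; // ~0.1f as binary16

    // buggy: the second scale enters the product as its raw bit pattern (11878)
    float wrong = fp16_to_fp32(d_x) * (float) d_y;
    // fixed: both scales are converted to float before multiplying
    float right = fp16_to_fp32(d_x) * fp16_to_fp32(d_y);

    printf("wrong scale: %f\n", wrong); // ~11878.0
    printf("right scale: %f\n", right); // ~0.1
    return 0;
}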
@@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_l = vld1q_s8(b_y1->qs);
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
-        float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);
 
         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

@@ -1811,7 +1812,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
 
         sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                            l1, r1)), l2, r2)), l3, r3))), scale);
     }
 
     float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
@@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const block_q8_1 * restrict b_y0 = &vy0[i];
         const block_q8_1 * restrict b_y1 = &vy1[i];
 
-        float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
-                                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-                                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+        float32_t summs_t[4] = {
+            GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+            GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+            GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+            GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+        };
         summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
         // mmla into int32x4_t
-        float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-                               GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-                               GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-                               GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);
 
         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

@@ -2389,15 +2394,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
         int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
 
         sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                            l1, r1)), l2, r2)), l3, r3))), scale);
     }
 
-    float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+    float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
     float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
 
     sumv2 = vaddq_f32(sumv2, summs0);
 
     vst1_f32(s, vget_low_f32 (sumv2));
     vst1_f32(s + bs, vget_high_f32(sumv2));
 
     return;
 }
 #endif
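As background rather than part of this change: all three kernels carry a four-element _scale array because the vmmlaq_s32 path evaluates a 2x2 tile of block dot products at once, two x blocks against two y blocks, and each of the four int32 results is scaled by its own d_x*d_y product. A scalar sketch of that tile follows, using hypothetical names (dot_i8, QK) and made-up scale values, only to show the ordering that _scale uses in the hunks above.

// Illustration only (not ggml code): the 2x2 tile behind the four _scale entries.
#include <stdint.h>
#include <stdio.h>

#define QK 32 // elements per block, matching QK4_1/QK8_1 in ggml

static int32_t dot_i8(const int8_t * x, const int8_t * y) {
    int32_t acc = 0;
    for (int i = 0; i < QK; ++i) acc += (int32_t) x[i] * (int32_t) y[i];
    return acc;
}

int main(void) {
    int8_t x0[QK], x1[QK], y0[QK], y1[QK];
    for (int i = 0; i < QK; ++i) { x0[i] = 1; x1[i] = 2; y0[i] = 3; y1[i] = -1; }

    // per-block scales, already converted from FP16 to float (made-up values)
    float d_x0 = 0.5f, d_x1 = 0.25f, d_y0 = 0.1f, d_y1 = 0.2f;

    // same ordering as the _scale[4] array in the diff: x0*y0, x0*y1, x1*y0, x1*y1
    float tile[4] = {
        d_x0*d_y0 * (float) dot_i8(x0, y0),
        d_x0*d_y1 * (float) dot_i8(x0, y1),
        d_x1*d_y0 * (float) dot_i8(x1, y0),
        d_x1*d_y1 * (float) dot_i8(x1, y1),
    };
    printf("%f %f %f %f\n", tile[0], tile[1], tile[2], tile[3]);
    return 0;
}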
@@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const int8x16_t y1_l = vld1q_s8(b_y1->qs);
         const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
 
-        float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                               GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                               GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                               GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+        float32_t _scale[4] = {
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+            GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+        };
         float32x4_t scale = vld1q_f32(_scale);
 
         int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

@@ -3393,13 +3402,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
 
         sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
                                                            l1, r1)), l2, r2)), l3, r3))), scale);
     }
 
-    float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+
+    float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
     float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
 
-    vst1_f32(s, vget_low_f32(sumv2));
+    vst1_f32(s, vget_low_f32 (sumv2));
     vst1_f32(s + bs, vget_high_f32(sumv2));
 
     return;
 }
 #endif
@@ -7641,8 +7641,8 @@ UseGgmlGemm2:;
     // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
     int64_t num_rows_per_vec_dot = vec_dot_num_rows;
 
-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
+    // these checks are needed to avoid crossing dim1 boundaries
+    // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
     if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
         num_rows_per_vec_dot = 1;
     }
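The reworded comments in the last hunk describe the guard directly below them: the mmla kernels consume two rows and two columns per vec_dot call, so the 2-wide path is only taken when every relevant extent is even, and otherwise the code falls back to one row at a time. The following is a compileable paraphrase of that guard, not ggml code; the wrapper function and its signature are hypothetical, while the variable names follow the diff.

// Sketch of the fallback logic: keep 2 rows per vec_dot only when all extents are even.
#include <stdint.h>

static int64_t pick_rows_per_vec_dot(int64_t vec_dot_num_rows,
                                     int64_t nr0, int64_t ne11,
                                     int64_t ir0_start, int64_t ir0_end,
                                     int64_t ir1_start, int64_t ir1_end) {
    int64_t num_rows_per_vec_dot = vec_dot_num_rows; // 2 when an mmla kernel is available

    // an odd extent would make a 2-wide tile cross a dim1 boundary, so use the 1x1 dot kernel
    if ((nr0 % 2 != 0) || (ne11 % 2 != 0) ||
        ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
        num_rows_per_vec_dot = 1;
    }
    return num_rows_per_vec_dot;
}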