mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-03-01 03:26:12 +00:00
metal : optimize dequant q6_K kernel (llama/11892)
This commit is contained in:
parent
a7fc1038ca
commit
4444db7360
@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
|
||||
template <typename type4x4>
|
||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d_all = xb->d;
|
||||
device const uint8_t * ql = (device const uint8_t *)xb->ql;
|
||||
device const uint8_t * qh = (device const uint8_t *)xb->qh;
|
||||
device const uint16_t * ql = (device const uint16_t *)xb->ql;
|
||||
device const uint16_t * qh = (device const uint16_t *)xb->qh;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
|
||||
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
|
||||
qh = qh + 32*(il/8) + 16*(il&1);
|
||||
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
|
||||
qh = qh + 16*(il/8) + 8*(il&1);
|
||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
||||
il = (il/2) & 3;
|
||||
|
||||
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
|
||||
const float coef = il>1 ? 1.f/16.f : 1.f;
|
||||
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
|
||||
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
|
||||
const float ml = d_all * sc * 32.f;
|
||||
const float dl = d_all * sc * coef;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
|
||||
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
|
||||
reg[i/4][i%4] = dl * q - ml;
|
||||
const float dl0 = d_all * sc;
|
||||
const float dl1 = dl0 / 256.f;
|
||||
const float dl2 = dl0 / (256.f * 256.f);
|
||||
const float dl3 = dl0 / (256.f * 256.f * 256.f);
|
||||
const uint8_t shr_h = il>2 ? 2 : 0;
|
||||
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
|
||||
const uint8_t shr_l = il>1 ? 4 : 0;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
|
||||
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
|
||||
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
|
||||
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
|
||||
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
|
||||
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
|
||||
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user