mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-09 04:15:15 +00:00
vulkan: workaround for AMD Windows driver 16 bit unpack8 bug (llama/12472)
This commit is contained in:
parent
fa2b5249ff
commit
2f77a9e9bd
@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
|
||||
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
|
||||
return vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
#endif
|
||||
|
@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const float db = d * (0.5 + scale) * 0.25;
|
||||
|
||||
const uint qh = data_a[ibi].qh[ib32];
|
||||
const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]);
|
||||
const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]);
|
||||
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
|
||||
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
|
||||
[[unroll]] for (uint l = 0; l < 2; ++l) {
|
||||
const uint8_t sign = sign16[l];
|
||||
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
|
||||
|
@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
|
||||
sum[j] = 0.0;
|
||||
}
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]);
|
||||
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
|
||||
const uint sign = data_a[ibi].signs[4 * ib32 + l];
|
||||
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
|
||||
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
|
||||
|
@ -336,8 +336,8 @@ void main() {
|
||||
const uint iqs = idx & 0x07;
|
||||
|
||||
const float d = float(data_a_packed16[ib].d);
|
||||
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
|
||||
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
|
||||
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
@ -544,7 +544,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@ -564,7 +564,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@ -586,7 +586,7 @@ void main() {
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@ -611,7 +611,7 @@ void main() {
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
@ -631,7 +631,7 @@ void main() {
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
|
Loading…
x
Reference in New Issue
Block a user