diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 78c3bddf..0eba3742 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -323,15 +323,16 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo const uint8_t qs = bl.block.qs[iqs]; const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); - const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28)); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); sign |= bitCount(sign) << 7; - const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3]; + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); - float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); - - return ret; + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); } #endif @@ -350,14 +351,16 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor const uint iqs = (idx & 0xF8) >> 3; // 0..63 const uint16_t qs = bl.block.qs[iqs]; - const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF)); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); uint sign = uint(qs >> 9); sign |= bitCount(sign) << 7; - const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3]; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); - float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); - return ret; + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); } #endif @@ -369,24 +372,23 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2 float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); - const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; const uint qs = bl.block.qs[ib8]; const uint qh = bl.block.qh[ib32]; - const uint qhshift = 2 * (ib8 % 4); - const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); const float d = float(bl.block.d); const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); - return float16_t(v[lsb]); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); } #endif @@ -401,28 +403,25 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3 float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint iqs = (idx % 128) / 2; // 0..63 - const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values const float d = float(bl.block.d); const uint qs = bl.block.qs[iqs]; - const uint signs = pack32(u8vec4( - bl.block.qs[is+0], - bl.block.qs[is+1], - bl.block.qs[is+2], - bl.block.qs[is+3] + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - return float16_t(v[lsb]); + return float16_t(v[idx & 1]); } #endif @@ -434,23 +433,21 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3 float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { uint idx = coordInBlock[1]; - uint lsb = idx & 1; - idx /= 2; - const uint iqs = (idx % 128) / 2; // 0..63 - const uint iqh = iqs / 8; + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; const float d = float(bl.block.d); const uint qs = bl.block.qs[iqs]; const uint qh = bl.block.qh[iqh]; - const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); const uint scale = bl.block.scales[iqs / 16]; - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); - return float16_t(v[lsb]); + return float16_t(v[idx & 1]); } #endif