Fix DMMV dequantization (llama/9279)

Fixed dmmv dequant for ncols== GGML_SYCL_DMMV_X
This commit is contained in:
Ouadie EL FAROUKI 2024-09-04 16:26:33 +01:00 committed by Georgi Gerganov
parent 3764bc974c
commit 1cecfe6a02

View File

@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
for (int mask = mask_start; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}