Minor arithmetic improvement to mmvq wrapper kernel (llama/7172)

This commit is contained in:
Ouadie EL FAROUKI 2024-05-10 01:32:15 +01:00 committed by Georgi Gerganov
parent c114b75aee
commit fe454b8d9e

View File

@ -8330,24 +8330,26 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
const int blocks_per_row = ncols / qk; const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi; const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread const int qi_vdr = (qi / vdr); // N_threads processing 1 qk block
// partial sum for each thread
float tmp = 0.0f; float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx; const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy; const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; for (int i = item_ct1.get_local_id(2) / qi_vdr; i < blocks_per_row;
i += blocks_per_warp) { i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index const int ibx = row * blocks_per_row + i; // x block index
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs = const int iqs =
vdr * vdr *
(item_ct1.get_local_id(2) % (item_ct1.get_local_id(2) -
(qi / vdr)); // x block quant index when casting the quants to int i * qi_vdr); // x block quant index when casting the quants to int
tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs);
} }
// sum up partial sums and write back result // sum up partial sums and write back result