diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index 6fafa528..6d568458 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -3494,6 +3494,31 @@ typedef struct dpct_type_block_iq3_xxs {
 } block_iq3_xxs;
 static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
 
+#define QR3_XS 8
+#define QI3_XS (QK_K / (4*QR3_XS))
+#if QK_K == 64
+#define IQ3S_N_SCALE 2
+#else
+#define IQ3S_N_SCALE QK_K/64
+#endif
+typedef struct {
+    sycl::half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+#define QR1_S 8
+#define QI1_S (QK_K / (4*QR1_S))
+typedef struct {
+    sycl::half d;
+    uint8_t qs[QK_K/8];
+    uint8_t scales[QK_K/16];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -4833,6 +4858,62 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
 }
 
+template<typename dst_t>
+static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                   const sycl::nd_item<3> &item_ct1,
+                                   const uint32_t *iq3s_grid,
+                                   const uint8_t *ksigns_iq2xs,
+                                   const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq3_s * x = (const block_iq3_s *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + qs[2*il+0]);
+    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + qs[2*il+1]);
+    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+    const uint8_t signs = x[i].signs[4*ib + il];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+#else
+    assert(false);
+#endif
+
+}
+
+template<typename dst_t>
+static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
+                                   const sycl::nd_item<3> &item_ct1,
+                                   const uint64_t *iq1s_grid,
+                                   const uint8_t *ksigns_iq2xs,
+                                   const uint8_t *kmask_iq2xs) {
+
+    const int i = item_ct1.get_group(2);
+    const block_iq1_s * x = (const block_iq1_s *) vx;
+
+    const int tid = item_ct1.get_local_id(2);
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const int i8 = 4*ib+il;
+    uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
+    const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
+    const float d = (float)x[i].d * (2*(h & 7) + 1);
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
+#else
+    assert(false);
+#endif
+
+}
+
 /*
 DPCT1110:4: The total declared local variable size in device function
 dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register
@@ -7679,6 +7760,76 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
 #endif
 }
 
+static __dpct_inline__ float
+vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
+    MIN_CC_DP4A // lowest compute capability for integer intrinsics
+#if QK_K == 256
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+
+    const int ib32 = iqs;
+    const uint8_t * qs = bq2->qs + 8*ib32;
+    const int8_t * q8 = bq8_1[ib32].qs;
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
+        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
+        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+            ((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
+        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+            grid1[0] ^ signs0, signs0, std::minus<>());
+        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+            grid2[0] ^ signs1, signs1, std::minus<>());
+        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        q8 += 8;
+    }
+    const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
+    return d * sumi;
+#else
+    assert(false);
+    return 0.f;
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+static __dpct_inline__ float
+vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
+                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
+                   const uint64_t *iq1s_grid, const uint64_t *ksigns64) {
+#if QK_K == 256
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+
+    const int ib32 = iqs;
+    int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+    const uint8_t h1 = bq1->scales[2*ib32+0];
+    const uint8_t h2 = bq1->scales[2*ib32+1];
+    const int * q8 = (const int *)bq8_1[ib32].qs;
+    const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
+    const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
+    const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
+    const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
+    for (int j = 0; j < 2; ++j) {
+        sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1);
+        sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2);
+        sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3);
+        sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4);
+    }
+    const float d = (float)bq1->d * bq8_1[ib32].ds[0];
+    return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
+                sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t * x = (const block_q_t *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
+template <int qk, int qi, typename block_q_t, int vdr>
+static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
+                                     const sycl::nd_item<3> &item_ct1,
+                                     const uint64_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
+    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
+                    item_ct1.get_local_id(1);
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int blocks_per_row = ncols / qk;
+    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+
+// partial sum for each thread
+    float tmp = 0.0f;
+
+    const block_q_t * x = (const block_q_t *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row;
+         i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
+
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
+
+        const int iqs =
+            vdr *
+            (item_ct1.get_local_id(2) %
+             (qi / vdr)); // x block quant index when casting the quants to int
+
+        tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
+
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[row] = tmp;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
                                    const sycl::nd_item<3> &item_ct1) {
@@ -10129,6 +10372,64 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+                                      dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        iq3s_grid.init(*stream);
+        ksigns_iq2xs.init(*stream);
+        kmask_iq2xs.init(*stream);
+
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
+            auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
+            auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
+
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq3_s(
+                                     vx, y, item_ct1, iq3s_grid_ptr_ct1,
+                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                             });
+        });
+    }
+}
+
+template <typename dst_t>
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+                                      dpct::queue_ptr stream) {
+    const int nb = k / QK_K;
+    {
+        iq1s_grid.init(*stream);
+        ksigns_iq2xs.init(*stream);
+        kmask_iq2xs.init(*stream);
+
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
+            auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
+            auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
+
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_s(
+                                     vx, y, item_ct1, iq1s_grid_ptr_ct1,
+                                     ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
+                             });
+        });
+    }
+}
+
 template <typename dst_t>
 static void convert_unary_sycl(const void *__restrict__ vx,
                                dst_t *__restrict__ y, const int k,
@@ -10179,6 +10480,10 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try {
             return dequantize_row_iq2_xs_sycl;
         case GGML_TYPE_IQ3_XXS:
             return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl;
         default:
@@ -10219,6 +10524,10 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
            return dequantize_row_iq2_xs_sycl;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_sycl;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_sycl;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_sycl;
        case GGML_TYPE_F16:
            return convert_unary_sycl;
        default:
@@ -10808,6 +11117,61 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols,
+                                        const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        iq3s_grid.init(*stream);
+        ksigns64.init(*stream);
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr();
+            auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_XS, block_iq3_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1,
+                            iq3s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                    });
+        });
+    }
+}
+
+static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
+                                        float *dst, const int ncols,
+                                        const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    {
+        iq1s_grid.init(*stream);
+        ksigns64.init(*stream);
+
+        stream->submit([&](sycl::handler &cgh) {
+            auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
+            auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
+
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[intel::reqd_sub_group_size(32)]] {
+                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1,
+                            iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
+                    });
+        });
+    }
+}
 
 static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                                         float *dst, const int ncols_x,
@@ -13556,8 +13920,11 @@ static int64_t get_row_rounding(ggml_type type, const std::array
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ3_XXS:
             return max_compute_capability >= VER_GEN9 ? 128 : 64;
+        case GGML_TYPE_IQ3_S:
+            return max_compute_capability >= VER_GEN9 ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
         default:
@@ -13618,6 +13985,12 @@ inline void ggml_sycl_op_mul_mat_vec_q(
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -16963,9 +17336,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
             return false;
         }
         ggml_type a_type = a->type;
-        if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-            a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S ||
-            a_type == GGML_TYPE_IQ2_S   || a_type == GGML_TYPE_IQ4_XS) {
+        if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ2_S ||
+            a_type == GGML_TYPE_IQ4_XS) {
            return false;
        }
        return true;
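
The two structs introduced at the top of the patch fix the on-device layout of IQ3_S and IQ1_S super-blocks. Below is a minimal host-side sketch of the size arithmetic the new static_asserts encode, assuming QK_K == 256 and using uint16_t as a stand-in for sycl::half; the field names mirror the patch, the _host suffix and main() are added here for illustration only.

#include <cstdint>
#include <cstdio>

constexpr int QK_K = 256;
constexpr int IQ3S_N_SCALE = QK_K/64;   // 4 scale bytes per super-block

struct block_iq3_s_host {
    uint16_t d;                     // fp16 super-block scale (2 bytes)
    uint8_t  qs[QK_K/4];            // low 8 bits of each grid index, 2 per group of 8 weights
    uint8_t  qh[QK_K/32];           // high (9th) grid-index bits, 1 byte per 32 weights
    uint8_t  signs[QK_K/8];         // one sign bit per weight
    uint8_t  scales[IQ3S_N_SCALE];  // two 4-bit sub-block scales per byte
};

struct block_iq1_s_host {
    uint16_t d;                     // fp16 super-block scale
    uint8_t  qs[QK_K/8];            // one 8-bit grid index per group of 8 weights
    uint8_t  scales[QK_K/16];       // packed nibbles: 3-bit sub-scale + 1 high index bit
};

int main() {
    // expected: 2 + 13*(QK_K/32) + IQ3S_N_SCALE = 110 bytes, as in the iq3_s static_assert
    printf("iq3_s: %zu\n", sizeof(block_iq3_s_host));
    // expected: 2 + QK_K/8 + QK_K/16 = 50 bytes, as in the iq1_s static_assert
    printf("iq1_s: %zu\n", sizeof(block_iq1_s_host));
    return 0;
}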
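
The IQ3_S kernels recover one scale per 32-weight sub-block from a packed nibble and then flip signs per weight. A scalar sketch of just that arithmetic, mirroring dequantize_block_iq3_s and vec_dot_iq3_s_q8_1 in the patch; the helper names are hypothetical, and it assumes kmask_iq2xs[j] == 1 << j as in the existing IQ2/IQ3 tables.

#include <cstdint>

// Sub-block scale: nibble ib of scales[], mapped to the odd multiplier 1, 3, ..., 31
// of the fp16 super-block scale d (same expression as in the kernels).
static inline float iq3_s_sub_scale(float d, const uint8_t * scales, int ib) {
    return d * (1 + 2*((scales[ib/2] >> 4*(ib%2)) & 0xf));
}

// Apply one signs byte to 8 grid magnitudes (two 4-byte iq3s_grid entries):
// bit j of `signs` negates weight j.
static inline void iq3_s_apply_signs(float dl, const uint8_t * grid_bytes,
                                     uint8_t signs, float out[8]) {
    for (int j = 0; j < 8; ++j) {
        out[j] = dl * grid_bytes[j] * (((signs >> j) & 1) ? -1.f : 1.f);
    }
}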
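
IQ1_S packs, for every 8 weights, one 8-bit grid index in qs plus a nibble in scales that carries a 3-bit sub-scale and the 9th index bit. A host-side sketch of decoding one such group, following dequantize_block_iq1_s in the patch; the function name is hypothetical, and the iq1s_grid codebook is passed in as a parameter rather than redeclared here.

#include <cstdint>

// d            : fp16 super-block scale converted to float
// qs           : one byte of block_iq1_s::qs (low 8 bits of the grid index)
// scale_nibble : the 4-bit field for this group (low 3 bits = scale, bit 3 = high index bit)
// iq1s_grid    : the codebook of packed int8x8 entries used by the kernels
static inline void decode_iq1_s_group(float d, uint8_t qs, uint8_t scale_nibble,
                                      const uint64_t * iq1s_grid, float out[8]) {
    const float dl = d * (2*(scale_nibble & 7) + 1);
    const int8_t * grid = (const int8_t *)(iq1s_grid + (qs | ((scale_nibble & 8) << 5)));
    for (int j = 0; j < 8; ++j) {
        out[j] = dl * grid[j];
    }
}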