diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index de814ef9..f78a36dd 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -14,23 +14,24 @@ #define GGML_SYCL_BACKEND_HPP #include "binbcast.hpp" -#include "concat.hpp" #include "common.hpp" +#include "concat.hpp" #include "conv.hpp" #include "convert.hpp" +#include "cpy.hpp" #include "dequantize.hpp" #include "dmmv.hpp" +#include "element_wise.hpp" +#include "gla.hpp" +#include "im2col.hpp" #include "mmq.hpp" #include "mmvq.hpp" -#include "rope.hpp" #include "norm.hpp" +#include "outprod.hpp" +#include "quants.hpp" +#include "rope.hpp" #include "softmax.hpp" #include "tsembd.hpp" -#include "im2col.hpp" #include "wkv.hpp" -#include "outprod.hpp" -#include "element_wise.hpp" -#include "cpy.hpp" -#include "gla.hpp" -#endif // GGML_SYCL_BACKEND_HPP +#endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 69aad938..60909dde 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; +extern int g_ggml_sycl_prioritize_dmmv; #define GGML_SYCL_DEBUG(...) \ do { \ diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 68a26fa4..0ea72994 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -49,6 +49,7 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -195,11 +196,13 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); @@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons std::exit(1); } +enum class mul_mat_algo { + DMMV = 0, + MMVQ = 1, + MUL_MAT_SYCL = 2, +}; + inline bool ggml_sycl_supports_mmq(enum ggml_type type) { // TODO: accuracy issues in MMQ GGML_UNUSED(type); return false; } +inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + static bool ggml_sycl_supports_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows, GGML_ASSERT((size % sizeof(block_q4_0) == 0)); GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; + auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; stream->parallel_for( @@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { reorder_qw(data_device, ncols, nrows, size, 0, stream); } -/* -* This function could be called when the OP (mul_mat) function support reorder optimizition. -*/ -static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, - ggml_tensor * dst) { - if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT - ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. - dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - src0->type == GGML_TYPE_Q4_0 && - src1->ne[2]==1 && src1->ne[3]==1) { +static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { + return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT + ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. + dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. + dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; +} - ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; - if (!extra) return; //only happen in CI/UT permute case. - - if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder. - - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. +static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, + ggml_tensor * dst, mul_mat_algo mm_algorithm) { + if (!should_reorder_tensor(*ctx, dst)) { + return; } + + ggml_tensor_extra_gpu * extra = static_cast(src0->extra); + if (!extra || extra->optimized_feature.reorder) { + return; // Skip permutations and already reordered tensors + } + + switch (mm_algorithm) { + case mul_mat_algo::DMMV: + if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { + return; + } + break; + case mul_mat_algo::MMVQ: + if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { + return; + } + break; + case mul_mat_algo::MUL_MAT_SYCL: + if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { + return; + } + break; + } + + reorder_qw(src0, ctx->stream()); + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering } static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor int64_t min_compute_capability = INT_MAX; if (split) { - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = + (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; for (int id = 0; id < ggml_sycl_info().device_count; ++id) { // skip devices that are not going to do any work: @@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } } } else { - min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; + min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; } // check data types and tensor shapes for custom matrix multiplication kernels: @@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX + // mmvq path is faster in the CUDA backend. - if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda) + if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // TODO: Refactor and cleanup of mul mat dispatching. @@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); + constexpr bool convert_src1_to_q8_1 = false; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); } else if (use_mul_mat_vec_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + constexpr bool convert_src1_to_q8_1 = true; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); } else if (use_mul_mat_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + constexpr bool convert_src1_to_q8_1 = true; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); } else { - opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder. - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + constexpr bool convert_src1_to_q8_1 = false; + // MUL_MAT_SYCL supports reorder + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); } + GGML_SYCL_DEBUG("call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b92ba2d..3cade1a4 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,60 @@ #include "mmvq.hpp" + +#include "ggml.h" +#include "common.hpp" +#include "quants.hpp" #include "vecdotq.hpp" -#include + +template +static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + const block_q8_1 * y = (const block_q8_1 *) vy; + + float partial_sum = 0.0f; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; // x block index + // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits + const int bx_offset = block_type::get_block_offset(ibx); + const int d_offset = block_type::get_d_offset(nrows, ncols, ibx); + + // Y block index that aligns with ibx + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + // x block quant index when casting the quants to int + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs); + } + } + + auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>()); + + if (sg.leader()) { + dst[row] = sum; + } +} template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } } -static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, +static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { - - stream->submit([&](sycl::handler &cgh) { - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -916,93 +983,95 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } -void ggml_sycl_op_mul_mat_vec_q( - ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_col_size, - const dpct::queue_ptr &stream) { - +void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, + const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size, + const dpct::queue_ptr & stream) { const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); - const int64_t ne00 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - for (int i = 0; i < src1_ncols; i++) - { + for (int i = 0; i < src1_ncols; i++) { const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; - const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; - float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + default: + GGML_ABORT("fatal error"); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp new file mode 100644 index 00000000..a74e3052 --- /dev/null +++ b/ggml/src/ggml-sycl/quants.hpp @@ -0,0 +1,61 @@ +// +// MIT license +// Copyright (C) 2025 Codeplay Software Ltd. +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_QUANTS_HPP +#define GGML_SYCL_QUANTS_HPP + +#include "ggml-common.h" +#include "ggml.h" + +namespace ggml_sycl_reordered { + + +// The reordered block moves quants (qs) and scales(d) to two +// uniform regions of memory that is contiguous in the same tensor. +// What this means is that instead of having: +// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN] +// We have: +// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] +// +// Notes: out-of-bounds qs will run into d values +// Aligment relies on the allocated size of qs + +template struct block_q_t; + + +// qk number of weights / quants in a block +// qr number of weights in a byte (described as 'before dequantization') +// for quantization types that has low and high bits split, qr is calculated with +// using the lower bits, e.g for Q6 quants QR6 is 2 +// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field) +// See ggml-common.h to see how these are calculated +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK4_0; + static constexpr uint32_t qi = QI4_0; + static constexpr uint32_t qr = QR4_0; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + + static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half); + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + +} // namespace ggml_sycl_reordered + +#endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index c5942008..cbf664fc 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -14,8 +14,11 @@ #define GGML_SYCL_VECDOTQ_HPP #include "dpct/helper.hpp" +#include "ggml.h" +#include "quants.hpp" -typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs); static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -252,13 +255,60 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +template struct reorder_vec_dot_q_sycl { + static_assert(T != T, "ggml_type for reorder vecdot not implemented"); +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_0; + + using q4_0_block = ggml_sycl_reordered::block_q_t; + using q4_0_traits = typename q4_0_block::traits; + + __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, + const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); + int v[q4_0_traits::vdr_mmvq]; + int u[2 * q4_0_traits::vdr_mmvq]; + +#pragma unroll + + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_uint8(bq4_0, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi); + } + + return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds); + }; +}; + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, - const float &d4, - const sycl::half2 &ds8) { +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, + const sycl::half2 & ds8) { int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { @@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); } - const sycl::float2 ds8f = - ds8.convert(); + const sycl::float2 ds8f = ds8.convert(); // second part effectively subtracts 8 from each quant value return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); @@ -456,13 +505,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq, const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; #pragma unroll for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); } return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds);