From 4737a8c7802b5b998876a757cdae95fb0ab0c088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Scipione?= Date: Mon, 9 Jun 2025 11:47:07 +0200 Subject: [PATCH] sycl: Add reorder to Q6_K mmvq implementation (llama/13885) * Add Reorder to Q6_K mmvq implementation * Address PR comments: clean up comments * Remove unused parameter after refactoring q4_k * Adding inline to function and removing unnecessary reference to int --------- Signed-off-by: nscipione --- ggml/src/ggml-sycl/convert.cpp | 23 ++++++++- ggml/src/ggml-sycl/dequantize.hpp | 32 ++++++++++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 52 ++++++++++++++++++- ggml/src/ggml-sycl/mmvq.cpp | 36 +++++++++++--- ggml/src/ggml-sycl/quants.hpp | 48 ++++++++++++++---- ggml/src/ggml-sycl/vecdotq.hpp | 83 +++++++++++++++++++++++++++---- 6 files changed, 244 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 75bac98e..96d2583b 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template +static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); +} + template static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; case GGML_TYPE_IQ1_M: @@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; case GGML_TYPE_IQ1_M: diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 64e92f73..540539bb 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid / 32; // ip is 0 or 1 + const int64_t il = tid - 32 * ip; // 0...32 + const int64_t is = 8 * ip + il / 16; + + const uint8_t * base_ptr = static_cast(vx); + const auto ql_offset = ib * (QK_K / 2); + const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib; + const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib; + const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks; + const uint8_t * ql_ptr = base_ptr + ql_offset; + const uint8_t * qh_ptr = base_ptr + qh_offset; + const uint8_t * scales_ptr = base_ptr + base_scales_offset; + const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib; + + dst_t * y = yy + ib * QK_K + 128 * ip + il; + + const uint8_t * ql = ql_ptr + 64 * ip + il; + const uint8_t qh = *(qh_ptr + 32 * ip + il); + const int8_t * sc = reinterpret_cast(scales_ptr + is); + + y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + template static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3936f1ea..3693b0a4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) { + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && + !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. @@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q4_0: return true; case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: return !g_ggml_sycl_prioritize_dmmv; default: return false; @@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: return true; default: return false; @@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d sycl::free(tmp_buf, *stream); } +static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q6_K) == 0); + GGML_ASSERT(offset % sizeof(block_q6_K) == 0); + + const int nblocks = size / sizeof(block_q6_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * ql_ptr = data_device; + auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks; + sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks); + + stream + ->parallel_for(nblocks, + [=](auto i) { + const block_q6_K * x = (const block_q6_K *) tmp_buf; + const int ib = i; + + const uint8_t * ql = x[ib].ql; + const uint8_t * qh = x[ib].qh; + uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; + uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; + uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; + + for (int j = 0; j < QK_K / 2; ++j) { + base_ql_ptr[j] = ql[j]; + } + for (int j = 0; j < QK_K / 4; ++j) { + base_qh_ptr[j] = qh[j]; + } + + for (int j = 0; j < QK_K / 16; ++j) { + base_scales_ptr[j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].d; + }) + .wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; @@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { case GGML_TYPE_Q4_K: reorder_qw_q4_k(data_device, size, 0, stream); break; + case GGML_TYPE_Q6_K: + reorder_qw_q6_k(data_device, size, 0, stream); + break; default: GGML_ABORT("reorder_qw() called with unsupported type"); break; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 80c780b2..5b7f0640 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r float partial_sum = 0.0f; for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { - const int ibx = row * blocks_per_row + i; // x block index - // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits - const int bx_offset = block_type::get_block_offset(ibx); - const int d_offset = block_type::get_d_offset(nrows, ncols, ibx); + const int ibx = row * blocks_per_row + i; // x block index + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); // Y block index that aligns with ibx const int iby = i * block_type::block_to_q8_1_ratio(); const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; @@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r // x block quant index when casting the quants to int const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); - partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks); + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); } } @@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_IQ1_S: mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 88ec13ea..8b952db4 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -14,12 +14,13 @@ #ifndef GGML_SYCL_QUANTS_HPP #define GGML_SYCL_QUANTS_HPP +#include + #include "ggml-common.h" #include "ggml.h" namespace ggml_sycl_reordered { - // The reordered block moves quants (qs) and scales(d) to two // uniform regions of memory that is contiguous in the same tensor. // What this means is that instead of having: @@ -32,7 +33,6 @@ namespace ggml_sycl_reordered { template struct block_q_t; - // qk number of weights / quants in a block // qr number of weights in a byte (described as 'before dequantization') // for quantization types that has low and high bits split, qr is calculated with @@ -47,10 +47,12 @@ template <> struct block_q_t { static constexpr uint32_t vdr_mmvq = 2; }; - static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } - static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { - return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half); + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } @@ -64,20 +66,46 @@ template <> struct block_q_t { static constexpr uint32_t vdr_mmvq = 2; }; - static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); } + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } - static constexpr int get_d_offset(int nrows, int ncols, const int block_index) { + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { auto nblocks = (nrows * (ncols / traits::qk)); - return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)); + return { nblocks * (QK_K / 2), + (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } - - constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI6_K; + static constexpr uint32_t qr = QR6_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto low_bits_index = block_index * (traits::qk / traits::qr); + // the index of high bits it's after all low bits + auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); + return { low_bits_index, high_bits_index }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); + auto block_scales = total_qs_bytes + block_index * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + return { block_scales, sb_scale }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index fa258e4d..0a5d4999 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl { return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); } - __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) { - const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset; - const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset)); + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); int v[q4_0_traits::vdr_mmvq]; int u[2 * q4_0_traits::vdr_mmvq]; @@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl { using q4_k_block = ggml_sycl_reordered::block_q_t; using q4_k_traits = typename q4_k_block::traits; - float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset, - const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) { - const int ib = ibx_offset / (QK_K / 2); + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int ib = ibx_offset.first / (QK_K / 2); const uint8_t * base = static_cast(vbq); - const uint8_t * qs = base + ibx_offset; - const int total_qs_bytes = nblocks * (QK_K / 2); - const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE; - const ggml_half2 * dms = reinterpret_cast(base + d_offset); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); @@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl { } }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q6_K; + + using q6_k_block = ggml_sycl_reordered::block_q_t; + using q6_k_traits = typename q6_k_block::traits; + + __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, + const int8_t * __restrict__ scales, const float d, + const float * __restrict__ d8) { + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4 * i]; + + const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; + + const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, + dpct::sub_sat()); // vi = (vil | vih) - 32 + + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d * sumf; + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, + const int iqs) { + const int ib = ibx_offset.first / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * ql = base + ibx_offset.first; + const uint8_t * qh = base + ibx_offset.second; + const int8_t * scales = reinterpret_cast(base + d_offset.first); + const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4); + const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8); + const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4)); + + const int vl = get_int_from_uint8(ql, iqs); + const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift; + + const int8_t * scs = scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); + const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); + d8[i] = ds_values[0]; + } + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + } +}; #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4