sycl: Add reorder to Q6_K mmvq implementation (llama/13885)

* Add Reorder to Q6_K mmvq implementation

* Address PR comments: clean up comments

* Remove unused parameter after refactoring q4_k

* Add inline to function and remove unnecessary reference to int

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
Nicolò Scipione authored 2025-06-09 11:47:07 +02:00, committed by Georgi Gerganov
parent 8a70f4d18b
commit 4737a8c780
6 changed files with 244 additions and 30 deletions

View File

@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
#endif
}
template <typename dst_t>
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
const int64_t nb = k / QK_K;
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
}
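For reference: this launcher maps one 64-work-item work-group to each of the nb super-blocks (global range (1, 1, nb * 64) over a local range of (1, 1, 64)), so each work-item reconstructs QK_K / 64 = 4 weights, matching the four stores per work-item in dequantize_block_q6_K_reorder shown further down.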
template <typename dst_t>
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
} else {
return dequantize_row_q6_K_sycl;
}
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ1_M:
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
} else {
return dequantize_row_q6_K_sycl;
}
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ1_M:

View File

@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
#endif
}
template <typename dst_t>
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
const int64_t ib = item_ct1.get_group(2);
const int64_t tid = item_ct1.get_local_id(2);
const int64_t ip = tid / 32; // ip is 0 or 1
const int64_t il = tid - 32 * ip; // 0...31
const int64_t is = 8 * ip + il / 16;
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
const auto ql_offset = ib * (QK_K / 2);
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
const uint8_t * ql_ptr = base_ptr + ql_offset;
const uint8_t * qh_ptr = base_ptr + qh_offset;
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
dst_t * y = yy + ib * QK_K + 128 * ip + il;
const uint8_t * ql = ql_ptr + 64 * ip + il;
const uint8_t qh = *(qh_ptr + 32 * ip + il);
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
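To make the offset arithmetic above concrete, here is a small standalone sketch (for illustration only, not part of the patch) that recomputes the four region offsets for one block, assuming QK_K = 256: each block contributes 128 bytes of low bits (ql), 64 bytes of high bits (qh), 16 bytes of scales, and one ggml_half scale d, and each region is stored contiguously for all blocks before the next region begins.

// Hypothetical offset check for the reordered Q6_K layout, assuming QK_K = 256.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t QK_K = 256;
    const int64_t nb   = 4;  // number of super-blocks in the tensor
    const int64_t ib   = 2;  // block being dequantized

    const int64_t ql_offset     = ib * (QK_K / 2);                                       // 256
    const int64_t qh_offset     = (QK_K / 2) * nb + (QK_K / 4) * ib;                     // 512 + 128 = 640
    const int64_t scales_offset = (QK_K / 2) * nb + (QK_K / 4) * nb + (QK_K / 16) * ib;  // 768 + 32 = 800
    const int64_t d_offset      = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * nb;          // 208 * 4 = 832

    assert(ql_offset == 256 && qh_offset == 640 && scales_offset == 800 && d_offset == 832);
    // d itself is then indexed as ((const ggml_half *) (base + d_offset))[ib]
    return 0;
}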
template<typename dst_t>
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1,

View File

@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
-    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
+        !g_ggml_sycl_disable_optimize) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); // used to release it when the ctx is destroyed.
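(As with the earlier Q4_0/Q4_K reorder, the whole path is gated by g_ggml_sycl_disable_optimize, which llama.cpp exposes through the GGML_SYCL_DISABLE_OPT environment variable, so the Q6_K reorder can be switched off at run time if it misbehaves on a given device.)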
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
case GGML_TYPE_Q4_0:
return true;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
return false;
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return true;
default:
return false;
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
sycl::free(tmp_buf, *stream);
}
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
const int nblocks = size / sizeof(block_q6_K);
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
auto * ql_ptr = data_device;
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
stream
->parallel_for(nblocks,
[=](auto i) {
const block_q6_K * x = (const block_q6_K *) tmp_buf;
const int ib = i;
const uint8_t * ql = x[ib].ql;
const uint8_t * qh = x[ib].qh;
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
for (int j = 0; j < QK_K / 2; ++j) {
base_ql_ptr[j] = ql[j];
}
for (int j = 0; j < QK_K / 4; ++j) {
base_qh_ptr[j] = qh[j];
}
for (int j = 0; j < QK_K / 16; ++j) {
base_scales_ptr[j] = x[ib].scales[j];
}
dm_ptr[ib] = x[ib].d;
})
.wait_and_throw();
sycl::free(tmp_buf, *stream);
}
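In effect the kernel converts the array-of-structs block_q6_K layout into four contiguous struct-of-arrays regions (all ql, then all qh, then all scales, then all d). A minimal host-side equivalent, as a sketch only (it assumes QK_K = 256 and uses uint16_t as a stand-in for ggml_half):

// Hypothetical CPU reference for the Q6_K reorder, assuming QK_K = 256.
#include <cstdint>
#include <cstring>

constexpr int QK_K = 256;
using ggml_half = uint16_t;  // stand-in for the real half type

struct block_q6_K {
    uint8_t   ql[QK_K / 2];      // low 4 bits of the quants
    uint8_t   qh[QK_K / 4];      // high 2 bits of the quants
    int8_t    scales[QK_K / 16]; // 8-bit sub-block scales
    ggml_half d;                 // super-block scale
};

// out must hold nblocks * sizeof(block_q6_K) bytes
void reorder_q6_k_host(const block_q6_K * x, uint8_t * out, int nblocks) {
    uint8_t *   ql     = out;
    uint8_t *   qh     = ql + (QK_K / 2) * nblocks;
    uint8_t *   scales = qh + (QK_K / 4) * nblocks;
    ggml_half * d      = (ggml_half *) (scales + (QK_K / 16) * nblocks);
    for (int ib = 0; ib < nblocks; ++ib) {
        std::memcpy(ql + (QK_K / 2) * ib, x[ib].ql, QK_K / 2);
        std::memcpy(qh + (QK_K / 4) * ib, x[ib].qh, QK_K / 4);
        std::memcpy(scales + (QK_K / 16) * ib, x[ib].scales, QK_K / 16);
        d[ib] = x[ib].d;
    }
}

The device version above must first shadow the buffer into tmp_buf because it performs the reorder in place over data_device.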
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
uint8_t * data_device = (uint8_t *) src0->data;
size_t ncols = src0->ne[0];
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
case GGML_TYPE_Q4_K:
reorder_qw_q4_k(data_device, size, 0, stream);
break;
case GGML_TYPE_Q6_K:
reorder_qw_q6_k(data_device, size, 0, stream);
break;
default:
GGML_ABORT("reorder_qw() called with unsupported type");
break;

View File

@@ -32,10 +32,9 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
float partial_sum = 0.0f;
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
const int ibx = row * blocks_per_row + i; // x block index
-        // TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
-        const int bx_offset = block_type::get_block_offset(ibx);
-        const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
+        const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
+        const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
// Y block index that aligns with ibx
const int iby = i * block_type::block_to_q8_1_ratio();
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
// x block quant index when casting the quants to int
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
}
}
@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
}
}
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
constexpr size_t num_subgroups = 16;
GGML_ASSERT(block_num_y % num_subgroups == 0);
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
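A quick sanity check on the launch shape (hedged: WARP_SIZE and GGML_SYCL_MMV_Y are build- and device-dependent; 32 and 1 are only illustrative values): with GGML_SYCL_MMV_Y = 1, each sub-group of WARP_SIZE work-items accumulates one output row, and the block_num_y % num_subgroups == 0 assertion guarantees the rows divide evenly into work-groups.

// Hypothetical launch-shape arithmetic for reorder_mul_mat_vec_q6_k_q8_1_sycl,
// assuming WARP_SIZE = 32 and GGML_SYCL_MMV_Y = 1 (both are configuration-dependent).
#include <cassert>

int main() {
    const int nrows = 4096;
    const int WARP_SIZE = 32, GGML_SYCL_MMV_Y = 1, num_subgroups = 16;

    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;  // ceil_div -> 4096
    assert(block_num_y % num_subgroups == 0);

    const int global_x = block_num_y * WARP_SIZE;    // 131072 work-items in total
    const int local_x  = num_subgroups * WARP_SIZE;  // 512 work-items per work-group
    const int groups   = global_x / local_x;         // 256 work-groups
    assert(groups * num_subgroups == nrows);         // one sub-group per output row
    return 0;
}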
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);

View File

@@ -14,12 +14,13 @@
#ifndef GGML_SYCL_QUANTS_HPP
#define GGML_SYCL_QUANTS_HPP
#include <utility>
#include "ggml-common.h"
#include "ggml.h"
namespace ggml_sycl_reordered {
// The reordered block moves quants (qs) and scales(d) to two
// uniform regions of memory that are contiguous in the same tensor.
// What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
template <ggml_type type> struct block_q_t;
// qk number of weights / quants in a block
// qr number of weights in a byte (described as 'before dequantization')
// for quantization types that have low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
static constexpr uint32_t vdr_mmvq = 2;
};
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
-        return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
+        return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
     }
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
static constexpr uint32_t vdr_mmvq = 2;
};
-    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+    static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
+        return { block_index * (traits::qk / traits::qr), 0 };
+    }
-    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+    static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
         auto nblocks = (nrows * (ncols / traits::qk));
-        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+        return { nblocks * (QK_K / 2),
+                 (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
     }
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
};
template <> struct block_q_t<GGML_TYPE_Q6_K> {
struct traits {
static constexpr uint32_t qk = QK_K;
static constexpr uint32_t qi = QI6_K;
static constexpr uint32_t qr = QR6_K;
static constexpr uint32_t vdr_mmvq = 1;
};
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
auto low_bits_index = block_index * (traits::qk / traits::qr);
// the high bits are stored after all the low bits
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
return { low_bits_index, high_bits_index };
}
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk));
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
return { block_scales, sb_scale };
}
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
} // namespace ggml_sycl_reordered
#endif // GGML_SYCL_QUANTS_HPP
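Both getters now return an offset pair instead of a single int: for Q4_0 the second member is simply 0, for Q4_K it distinguishes the scale region from the dm region, and for Q6_K get_block_offset addresses the split low/high quant bits while get_d_offset addresses this block's scales and the start of the d array. A hypothetical standalone check of the Q6_K arithmetic (assuming QK_K = 256, so traits::qk / traits::qr = 128; C++14 or later):

// Hypothetical compile-time check of the block_q_t<GGML_TYPE_Q6_K> offsets, QK_K = 256.
#include <utility>

constexpr int QK_K = 256;

constexpr std::pair<int, int> q6k_block_offset(int block_index, int n_blocks) {
    return { block_index * (QK_K / 2),                            // low 4 bits of this block
             n_blocks * (QK_K / 2) + block_index * (QK_K / 4) };  // high 2 bits of this block
}

constexpr std::pair<int, int> q6k_d_offset(int nrows, int ncols, int block_index) {
    const int nblocks  = nrows * (ncols / QK_K);
    const int total_qs = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
    return { total_qs + block_index * (QK_K / 16),  // this block's scales
             total_qs + nblocks * (QK_K / 16) };    // start of the d array
}

// one row of 1024 weights -> 4 blocks
static_assert(q6k_block_offset(2, 4) == std::pair<int, int>{ 256, 640 }, "");
static_assert(q6k_d_offset(1, 1024, 2) == std::pair<int, int>{ 800, 832 }, "");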

View File

@@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
}
-    __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
-        const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
-        const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
+    __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
+                                     const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
+                                     const sycl::half2 * q8_1_ds, const int & iqs) {
+        const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
+        const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
int v[q4_0_traits::vdr_mmvq];
int u[2 * q4_0_traits::vdr_mmvq];
@@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
using q4_k_traits = typename q4_k_block::traits;
-    float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                     const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
-        const int ib = ibx_offset / (QK_K / 2);
+    __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
+                                     const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
+                                     const sycl::half2 * q8_1_ds, const int & iqs) {
+        const int ib = ibx_offset.first / (QK_K / 2);
         const uint8_t * base = static_cast<const uint8_t *>(vbq);
-        const uint8_t * qs = base + ibx_offset;
-        const int total_qs_bytes = nblocks * (QK_K / 2);
-        const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
-        const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
+        const uint8_t * qs = base + ibx_offset.first;
+        const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
+        const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
@@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
}
};
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
using q6_k_traits = typename q6_k_block::traits;
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
const int8_t * __restrict__ scales, const float d,
const float * __restrict__ d8) {
float sumf = 0.0f;
#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
const int sc = scales[4 * i];
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
dpct::sub_sat()); // vi = (vil | vih) - 32
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
}
return d * sumf;
}
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
const int iqs) {
const int ib = ibx_offset.first / (QK_K / 2);
const uint8_t * base = static_cast<const uint8_t *>(vbq);
const uint8_t * ql = base + ibx_offset.first;
const uint8_t * qh = base + ibx_offset.second;
const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
const int vl = get_int_from_uint8(ql, iqs);
const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
const int8_t * scs = scales + scale_offset;
int u[QR6_K];
float d8[QR6_K];
#pragma unroll
for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
d8[i] = ds_values[0];
}
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
}
};
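The packed arithmetic in vec_dot_q6_K_q8_1_impl_mmvq reassembles four 6-bit weights per 32-bit lane and recenters them by 32, the same reconstruction the dequantizer performs one weight at a time. A scalar sketch of a single weight (illustrative values only):

// Hypothetical scalar reference for one Q6_K weight: combine the low 4 bits with
// the 2 high bits, then recenter by -32 (stored values lie in [0, 63]).
#include <cassert>
#include <cstdint>

int main() {
    const uint8_t ql = 0x2A;  // low nibble in use here: 0xA
    const uint8_t qh = 0x02;  // 2-bit field in use: 0b10
    const int8_t  q  = (int8_t) (((ql & 0x0F) | ((qh & 0x03) << 4)) - 32);
    assert(q == (0x0A | 0x20) - 32);  // 42 - 32 = 10
    // dequantized value: d (super-block scale) * sc (sub-block scale) * q
    return 0;
}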
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4