mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-20 07:33:50 +00:00
sycl: Add reorder to Q6_K mmvq implementation (llama/13885)
* Add Reorder to Q6_K mmvq implementation * Address PR comments: clean up comments * Remove unused parameter after refactoring q4_k * Adding inline to function and removing unnecessary reference to int --------- Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
committed by
Georgi Gerganov
parent
8a70f4d18b
commit
4737a8c780
@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||||
|
const int64_t nb = k / QK_K;
|
||||||
|
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||||
|
|
||||||
|
stream->parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
||||||
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
|||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
return dequantize_row_q5_K_sycl;
|
return dequantize_row_q5_K_sycl;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
|
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||||
|
return dequantize_row_q6_K_sycl_reorder;
|
||||||
|
} else {
|
||||||
return dequantize_row_q6_K_sycl;
|
return dequantize_row_q6_K_sycl;
|
||||||
|
}
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
return dequantize_row_iq1_s_sycl;
|
return dequantize_row_iq1_s_sycl;
|
||||||
case GGML_TYPE_IQ1_M:
|
case GGML_TYPE_IQ1_M:
|
||||||
@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
|||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
return dequantize_row_q5_K_sycl;
|
return dequantize_row_q5_K_sycl;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
|
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||||
|
return dequantize_row_q6_K_sycl_reorder;
|
||||||
|
} else {
|
||||||
return dequantize_row_q6_K_sycl;
|
return dequantize_row_q6_K_sycl;
|
||||||
|
}
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
return dequantize_row_iq1_s_sycl;
|
return dequantize_row_iq1_s_sycl;
|
||||||
case GGML_TYPE_IQ1_M:
|
case GGML_TYPE_IQ1_M:
|
||||||
|
@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||||
|
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
|
||||||
|
const int64_t ib = item_ct1.get_group(2);
|
||||||
|
|
||||||
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
|
const int64_t ip = tid / 32; // ip is 0 or 1
|
||||||
|
const int64_t il = tid - 32 * ip; // 0...32
|
||||||
|
const int64_t is = 8 * ip + il / 16;
|
||||||
|
|
||||||
|
const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
|
||||||
|
const auto ql_offset = ib * (QK_K / 2);
|
||||||
|
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
|
||||||
|
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
|
||||||
|
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
|
||||||
|
const uint8_t * ql_ptr = base_ptr + ql_offset;
|
||||||
|
const uint8_t * qh_ptr = base_ptr + qh_offset;
|
||||||
|
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
|
||||||
|
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;
|
||||||
|
|
||||||
|
dst_t * y = yy + ib * QK_K + 128 * ip + il;
|
||||||
|
|
||||||
|
const uint8_t * ql = ql_ptr + 64 * ip + il;
|
||||||
|
const uint8_t qh = *(qh_ptr + 32 * ip + il);
|
||||||
|
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);
|
||||||
|
|
||||||
|
y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
||||||
|
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
||||||
|
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
||||||
|
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1,
|
const sycl::nd_item<3> &item_ct1,
|
||||||
|
@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
|
|||||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
|
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
|
||||||
|
!g_ggml_sycl_disable_optimize) {
|
||||||
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
|
||||||
tensor->extra = extra;
|
tensor->extra = extra;
|
||||||
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
|
||||||
@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
|||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
return true;
|
return true;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
return !g_ggml_sycl_prioritize_dmmv;
|
return !g_ggml_sycl_prioritize_dmmv;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
|||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|||||||
sycl::free(tmp_buf, *stream);
|
sycl::free(tmp_buf, *stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||||
|
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
|
||||||
|
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);
|
||||||
|
|
||||||
|
const int nblocks = size / sizeof(block_q6_K);
|
||||||
|
|
||||||
|
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||||
|
|
||||||
|
auto * ql_ptr = data_device;
|
||||||
|
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
||||||
|
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
||||||
|
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
||||||
|
|
||||||
|
stream
|
||||||
|
->parallel_for(nblocks,
|
||||||
|
[=](auto i) {
|
||||||
|
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
||||||
|
const int ib = i;
|
||||||
|
|
||||||
|
const uint8_t * ql = x[ib].ql;
|
||||||
|
const uint8_t * qh = x[ib].qh;
|
||||||
|
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
|
||||||
|
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
|
||||||
|
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K / 2; ++j) {
|
||||||
|
base_ql_ptr[j] = ql[j];
|
||||||
|
}
|
||||||
|
for (int j = 0; j < QK_K / 4; ++j) {
|
||||||
|
base_qh_ptr[j] = qh[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j = 0; j < QK_K / 16; ++j) {
|
||||||
|
base_scales_ptr[j] = x[ib].scales[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
dm_ptr[ib] = x[ib].d;
|
||||||
|
})
|
||||||
|
.wait_and_throw();
|
||||||
|
|
||||||
|
sycl::free(tmp_buf, *stream);
|
||||||
|
}
|
||||||
|
|
||||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||||
uint8_t * data_device = (uint8_t *) src0->data;
|
uint8_t * data_device = (uint8_t *) src0->data;
|
||||||
size_t ncols = src0->ne[0];
|
size_t ncols = src0->ne[0];
|
||||||
@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
|||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
reorder_qw_q4_k(data_device, size, 0, stream);
|
reorder_qw_q4_k(data_device, size, 0, stream);
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_Q6_K:
|
||||||
|
reorder_qw_q6_k(data_device, size, 0, stream);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("reorder_qw() called with unsupported type");
|
GGML_ABORT("reorder_qw() called with unsupported type");
|
||||||
break;
|
break;
|
||||||
|
@ -32,10 +32,9 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
|||||||
float partial_sum = 0.0f;
|
float partial_sum = 0.0f;
|
||||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||||
const int ibx = row * blocks_per_row + i; // x block index
|
const int ibx = row * blocks_per_row + i; // x block index
|
||||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
|
||||||
const int bx_offset = block_type::get_block_offset(ibx);
|
|
||||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
|
||||||
|
|
||||||
|
const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
|
||||||
|
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||||
// Y block index that aligns with ibx
|
// Y block index that aligns with ibx
|
||||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||||
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
|
||||||
@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
|
|||||||
// x block quant index when casting the quants to int
|
// x block quant index when casting the quants to int
|
||||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||||
|
|
||||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
|
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||||
|
const int nrows, dpct::queue_ptr stream) {
|
||||||
|
GGML_ASSERT(ncols % QK_K == 0);
|
||||||
|
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||||
|
constexpr size_t num_subgroups = 16;
|
||||||
|
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||||
|
|
||||||
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||||
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||||
|
|
||||||
|
stream->submit([&](sycl::handler & cgh) {
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||||
|
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
||||||
|
nd_item);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
float *dst, const int ncols,
|
float *dst, const int ncols,
|
||||||
const int nrows,
|
const int nrows,
|
||||||
@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|||||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
|
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||||
|
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||||
|
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||||
|
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||||
|
} else {
|
||||||
|
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
|
||||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||||
|
@ -14,12 +14,13 @@
|
|||||||
#ifndef GGML_SYCL_QUANTS_HPP
|
#ifndef GGML_SYCL_QUANTS_HPP
|
||||||
#define GGML_SYCL_QUANTS_HPP
|
#define GGML_SYCL_QUANTS_HPP
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "ggml-common.h"
|
#include "ggml-common.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
namespace ggml_sycl_reordered {
|
namespace ggml_sycl_reordered {
|
||||||
|
|
||||||
|
|
||||||
// The reordered block moves quants (qs) and scales(d) to two
|
// The reordered block moves quants (qs) and scales(d) to two
|
||||||
// uniform regions of memory that is contiguous in the same tensor.
|
// uniform regions of memory that is contiguous in the same tensor.
|
||||||
// What this means is that instead of having:
|
// What this means is that instead of having:
|
||||||
@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {
|
|||||||
|
|
||||||
template <ggml_type type> struct block_q_t;
|
template <ggml_type type> struct block_q_t;
|
||||||
|
|
||||||
|
|
||||||
// qk number of weights / quants in a block
|
// qk number of weights / quants in a block
|
||||||
// qr number of weights in a byte (described as 'before dequantization')
|
// qr number of weights in a byte (described as 'before dequantization')
|
||||||
// for quantization types that has low and high bits split, qr is calculated with
|
// for quantization types that has low and high bits split, qr is calculated with
|
||||||
@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
|||||||
static constexpr uint32_t vdr_mmvq = 2;
|
static constexpr uint32_t vdr_mmvq = 2;
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||||
|
return { block_index * (traits::qk / traits::qr), 0 };
|
||||||
|
}
|
||||||
|
|
||||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||||
@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
|
|||||||
static constexpr uint32_t vdr_mmvq = 2;
|
static constexpr uint32_t vdr_mmvq = 2;
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
|
||||||
|
return { block_index * (traits::qk / traits::qr), 0 };
|
||||||
|
}
|
||||||
|
|
||||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||||
auto nblocks = (nrows * (ncols / traits::qk));
|
auto nblocks = (nrows * (ncols / traits::qk));
|
||||||
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
|
return { nblocks * (QK_K / 2),
|
||||||
|
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||||
|
|
||||||
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
|
||||||
|
|
||||||
constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <> struct block_q_t<GGML_TYPE_Q6_K> {
|
||||||
|
struct traits {
|
||||||
|
static constexpr uint32_t qk = QK_K;
|
||||||
|
static constexpr uint32_t qi = QI6_K;
|
||||||
|
static constexpr uint32_t qr = QR6_K;
|
||||||
|
static constexpr uint32_t vdr_mmvq = 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
|
||||||
|
auto low_bits_index = block_index * (traits::qk / traits::qr);
|
||||||
|
// the index of high bits it's after all low bits
|
||||||
|
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
|
||||||
|
return { low_bits_index, high_bits_index };
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
|
||||||
|
auto nblocks = (nrows * (ncols / traits::qk));
|
||||||
|
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
|
||||||
|
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
|
||||||
|
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
|
||||||
|
return { block_scales, sb_scale };
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||||
|
};
|
||||||
} // namespace ggml_sycl_reordered
|
} // namespace ggml_sycl_reordered
|
||||||
|
|
||||||
#endif // GGML_SYCL_QUANTS_HPP
|
#endif // GGML_SYCL_QUANTS_HPP
|
||||||
|
@ -284,10 +284,11 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
|||||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||||
}
|
}
|
||||||
|
|
||||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int /* nblocks */) {
|
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset.first;
|
||||||
|
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset.first));
|
||||||
int v[q4_0_traits::vdr_mmvq];
|
int v[q4_0_traits::vdr_mmvq];
|
||||||
int u[2 * q4_0_traits::vdr_mmvq];
|
int u[2 * q4_0_traits::vdr_mmvq];
|
||||||
|
|
||||||
@ -346,15 +347,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
|||||||
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
using q4_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
|
||||||
using q4_k_traits = typename q4_k_block::traits;
|
using q4_k_traits = typename q4_k_block::traits;
|
||||||
|
|
||||||
float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||||
const int8_t* q8_1_quant_ptr, const sycl::half2* q8_1_ds, const int & iqs, int nblocks) {
|
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr,
|
||||||
const int ib = ibx_offset / (QK_K / 2);
|
const sycl::half2 * q8_1_ds, const int & iqs) {
|
||||||
|
const int ib = ibx_offset.first / (QK_K / 2);
|
||||||
|
|
||||||
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||||
const uint8_t * qs = base + ibx_offset;
|
const uint8_t * qs = base + ibx_offset.first;
|
||||||
const int total_qs_bytes = nblocks * (QK_K / 2);
|
const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE;
|
||||||
const uint8_t * scs = base + total_qs_bytes + ib * K_SCALE_SIZE;
|
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second);
|
||||||
const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset);
|
|
||||||
|
|
||||||
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
|
||||||
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
|
||||||
@ -395,6 +396,66 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> {
|
||||||
|
static constexpr ggml_type gtype = GGML_TYPE_Q6_K;
|
||||||
|
|
||||||
|
using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;
|
||||||
|
using q6_k_traits = typename q6_k_block::traits;
|
||||||
|
|
||||||
|
__dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u,
|
||||||
|
const int8_t * __restrict__ scales, const float d,
|
||||||
|
const float * __restrict__ d8) {
|
||||||
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < QR6_K; ++i) {
|
||||||
|
const int sc = scales[4 * i];
|
||||||
|
|
||||||
|
const int vil = (vl >> (4 * i)) & 0x0F0F0F0F;
|
||||||
|
|
||||||
|
const int vih = ((vh >> (4 * i)) << 4) & 0x30303030;
|
||||||
|
|
||||||
|
const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020,
|
||||||
|
dpct::sub_sat()); // vi = (vil | vih) - 32
|
||||||
|
|
||||||
|
sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product
|
||||||
|
}
|
||||||
|
|
||||||
|
return d * sumf;
|
||||||
|
}
|
||||||
|
|
||||||
|
__dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset,
|
||||||
|
const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds,
|
||||||
|
const int iqs) {
|
||||||
|
const int ib = ibx_offset.first / (QK_K / 2);
|
||||||
|
|
||||||
|
const uint8_t * base = static_cast<const uint8_t *>(vbq);
|
||||||
|
const uint8_t * ql = base + ibx_offset.first;
|
||||||
|
const uint8_t * qh = base + ibx_offset.second;
|
||||||
|
const int8_t * scales = reinterpret_cast<const int8_t *>(base + d_offset.first);
|
||||||
|
const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib;
|
||||||
|
|
||||||
|
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
|
||||||
|
const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
|
||||||
|
const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));
|
||||||
|
|
||||||
|
const int vl = get_int_from_uint8(ql, iqs);
|
||||||
|
const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift;
|
||||||
|
|
||||||
|
const int8_t * scs = scales + scale_offset;
|
||||||
|
|
||||||
|
int u[QR6_K];
|
||||||
|
float d8[QR6_K];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < QR6_K; ++i) {
|
||||||
|
u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1);
|
||||||
|
const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i);
|
||||||
|
d8[i] = ds_values[0];
|
||||||
|
}
|
||||||
|
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8);
|
||||||
|
}
|
||||||
|
};
|
||||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user