mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-28 04:54:13 +00:00
sycl : implementation of reordered Q4_0 MMVQ for Intel GPUs (llama/12858)
* sycl : Implemented reorder Q4_0 mmvq Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com> * sycl : Fixed mmvq being called when reorder is disabled * sycl : Improved comments in the quants header Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com> * Use static_assert * safe_div -> ceil_div * Clarify qi comment * change the reorder tensor from init to execute OP * dbg * Undo changes to test-backend-ops * Refactor changes on top of q4_0 reorder fix * Missing Reverts * Refactored opt_for_reorder logic to simplify code path * Explicit inlining and unroll * Renamed mul_mat_algo enum for consistency --------- Signed-off-by: Alberto Cabrera <alberto.cabrera@codeplay.com> Co-authored-by: romain.biessy <romain.biessy@codeplay.com>
This commit is contained in:
parent
2d436bfbfb
commit
45d8b2352e
@ -14,23 +14,24 @@
|
||||
#define GGML_SYCL_BACKEND_HPP
|
||||
|
||||
#include "binbcast.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "common.hpp"
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "gla.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "mmq.hpp"
|
||||
#include "mmvq.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "gla.hpp"
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
@ -42,6 +42,7 @@ void ggml_sycl_host_free(void* ptr);
|
||||
|
||||
extern int g_ggml_sycl_debug;
|
||||
extern int g_ggml_sycl_disable_optimize;
|
||||
extern int g_ggml_sycl_prioritize_dmmv;
|
||||
|
||||
#define GGML_SYCL_DEBUG(...) \
|
||||
do { \
|
||||
|
@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
||||
int g_ggml_sycl_debug = 0;
|
||||
int g_ggml_sycl_disable_optimize = 0;
|
||||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
@ -195,11 +196,13 @@ static void ggml_check_sycl() try {
|
||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||
GGML_LOG_INFO("Build with Macros:\n");
|
||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||
GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
|
||||
@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
enum class mul_mat_algo {
|
||||
DMMV = 0,
|
||||
MMVQ = 1,
|
||||
MUL_MAT_SYCL = 2,
|
||||
};
|
||||
|
||||
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
|
||||
// TODO: accuracy issues in MMQ
|
||||
GGML_UNUSED(type);
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
|
||||
auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
reorder_qw(data_device, ncols, nrows, size, 0, stream);
|
||||
}
|
||||
|
||||
/*
|
||||
* This function could be called when the OP (mul_mat) function support reorder optimizition.
|
||||
*/
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
|
||||
return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
|
||||
ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
|
||||
dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
|
||||
src0->type == GGML_TYPE_Q4_0 &&
|
||||
src1->ne[2]==1 && src1->ne[3]==1) {
|
||||
dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
|
||||
if (!extra) return; //only happen in CI/UT permute case.
|
||||
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
|
||||
ggml_tensor * dst, mul_mat_algo mm_algorithm) {
|
||||
if (!should_reorder_tensor(*ctx, dst)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
|
||||
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
|
||||
if (!extra || extra->optimized_feature.reorder) {
|
||||
return; // Skip permutations and already reordered tensors
|
||||
}
|
||||
|
||||
switch (mm_algorithm) {
|
||||
case mul_mat_algo::DMMV:
|
||||
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MMVQ:
|
||||
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case mul_mat_algo::MUL_MAT_SYCL:
|
||||
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
reorder_qw(src0, ctx->stream());
|
||||
extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
|
||||
}
|
||||
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
|
||||
}
|
||||
|
||||
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
int64_t min_compute_capability = INT_MAX;
|
||||
|
||||
if (split) {
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
ggml_backend_sycl_split_buffer_type_context * buft_ctx =
|
||||
(ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
auto & tensor_split = buft_ctx->tensor_split;
|
||||
for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
|
||||
// skip devices that are not going to do any work:
|
||||
@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||
#endif // SYCL_USE_XMX
|
||||
|
||||
|
||||
// mmvq path is faster in the CUDA backend.
|
||||
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
|
||||
if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
|
||||
// Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
|
||||
// is enabled takes precedence over DMMV, the current if-else implementation
|
||||
// requires disabling DMMV if both conditions are met
|
||||
|| (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// TODO: Refactor and cleanup of mul mat dispatching.
|
||||
@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
||||
// KQ + KQV multi-batch
|
||||
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
|
||||
} else if (use_dequantize_mul_mat_vec) {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
|
||||
// save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_vec_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
|
||||
} else if (use_mul_mat_q) {
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
|
||||
constexpr bool convert_src1_to_q8_1 = true;
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
|
||||
} else {
|
||||
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
|
||||
constexpr bool convert_src1_to_q8_1 = false;
|
||||
// MUL_MAT_SYCL supports reorder
|
||||
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
|
||||
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
|
||||
}
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,6 +1,60 @@
|
||||
#include "mmvq.hpp"
|
||||
|
||||
#include "ggml.h"
|
||||
#include "common.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "vecdotq.hpp"
|
||||
#include <cassert>
|
||||
|
||||
template <typename reorder_vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
|
||||
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
|
||||
using block_traits = typename block_type::traits;
|
||||
|
||||
const auto sg = nd_item.get_sub_group();
|
||||
const int sg_range = sg.get_group_linear_range();
|
||||
const int workgroup_id = nd_item.get_group_linear_id();
|
||||
const int sg_id = sg.get_group_linear_id();
|
||||
const int row = workgroup_id * sg_range + sg_id;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int blocks_per_row = ncols / block_traits::qk;
|
||||
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
||||
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
||||
|
||||
static_assert(blocks_per_subgroup > 0);
|
||||
static_assert(block_elements_per_subgroup > 0);
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
float partial_sum = 0.0f;
|
||||
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
||||
const int ibx = row * blocks_per_row + i; // x block index
|
||||
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
||||
const int bx_offset = block_type::get_block_offset(ibx);
|
||||
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
||||
|
||||
// Y block index that aligns with ibx
|
||||
const int iby = i * block_type::block_to_q8_1_ratio();
|
||||
|
||||
#pragma unroll
|
||||
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
||||
// x block quant index when casting the quants to int
|
||||
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
||||
|
||||
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
|
||||
}
|
||||
}
|
||||
|
||||
auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
|
||||
|
||||
if (sg.leader()) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
||||
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
@ -480,24 +534,37 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
const int nrows,
|
||||
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
||||
const int nrows, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
||||
constexpr size_t num_subgroups = 16;
|
||||
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
||||
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
|
||||
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
@ -916,14 +983,11 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_mul_mat_vec_q(
|
||||
ggml_backend_sycl_context & ctx,
|
||||
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
|
||||
float *dst_dd_i, const int64_t row_low, const int64_t row_high,
|
||||
const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
|
||||
const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
||||
const dpct::queue_ptr & stream) {
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
|
||||
@ -931,21 +995,26 @@ void ggml_sycl_op_mul_mat_vec_q(
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
int id;
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||
const size_t q8_1_ts = sizeof(block_q8_1);
|
||||
const size_t q8_1_bs = QK8_1;
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
|
||||
for (int i = 0; i < src1_ncols; i++)
|
||||
{
|
||||
for (int i = 0; i < src1_ncols; i++) {
|
||||
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
|
||||
const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
||||
float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
||||
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
||||
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
|
||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
|
61
ggml/src/ggml-sycl/quants.hpp
Normal file
61
ggml/src/ggml-sycl/quants.hpp
Normal file
@ -0,0 +1,61 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Codeplay Software Ltd.
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_QUANTS_HPP
|
||||
#define GGML_SYCL_QUANTS_HPP
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
namespace ggml_sycl_reordered {
|
||||
|
||||
|
||||
// The reordered block moves quants (qs) and scales(d) to two
|
||||
// uniform regions of memory that is contiguous in the same tensor.
|
||||
// What this means is that instead of having:
|
||||
// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN]
|
||||
// We have:
|
||||
// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN]
|
||||
//
|
||||
// Notes: out-of-bounds qs will run into d values
|
||||
// Aligment relies on the allocated size of qs
|
||||
|
||||
template <ggml_type type> struct block_q_t;
|
||||
|
||||
|
||||
// qk number of weights / quants in a block
|
||||
// qr number of weights in a byte (described as 'before dequantization')
|
||||
// for quantization types that has low and high bits split, qr is calculated with
|
||||
// using the lower bits, e.g for Q6 quants QR6 is 2
|
||||
// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field)
|
||||
// See ggml-common.h to see how these are calculated
|
||||
template <> struct block_q_t<GGML_TYPE_Q4_0> {
|
||||
struct traits {
|
||||
static constexpr uint32_t qk = QK4_0;
|
||||
static constexpr uint32_t qi = QI4_0;
|
||||
static constexpr uint32_t qr = QR4_0;
|
||||
static constexpr uint32_t vdr_mmvq = 2;
|
||||
};
|
||||
|
||||
static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
|
||||
|
||||
static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
|
||||
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
|
||||
}
|
||||
|
||||
static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
|
||||
};
|
||||
|
||||
} // namespace ggml_sycl_reordered
|
||||
|
||||
#endif // GGML_SYCL_QUANTS_HPP
|
@ -1,6 +1,6 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
@ -14,8 +14,11 @@
|
||||
#define GGML_SYCL_VECDOTQ_HPP
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
#include "quants.hpp"
|
||||
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
const int & iqs);
|
||||
|
||||
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
|
||||
const uint16_t* x16 =
|
||||
@ -252,12 +255,59 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
template <ggml_type T> struct reorder_vec_dot_q_sycl {
|
||||
static_assert(T != T, "ggml_type for reorder vecdot not implemented");
|
||||
};
|
||||
|
||||
template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
|
||||
static constexpr ggml_type gtype = GGML_TYPE_Q4_0;
|
||||
|
||||
using q4_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_0>;
|
||||
using q4_0_traits = typename q4_0_block::traits;
|
||||
|
||||
__dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
|
||||
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
|
||||
|
||||
// SIMD dot product of quantized values
|
||||
sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi);
|
||||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y());
|
||||
}
|
||||
|
||||
__dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
|
||||
const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
|
||||
const ggml_half d = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
|
||||
int v[q4_0_traits::vdr_mmvq];
|
||||
int u[2 * q4_0_traits::vdr_mmvq];
|
||||
|
||||
#pragma unroll
|
||||
|
||||
for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) {
|
||||
v[i] = get_int_from_uint8(bq4_0, iqs + i);
|
||||
u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
|
||||
u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + q4_0_traits::qi);
|
||||
}
|
||||
|
||||
return vec_dot_q4_0_q8_1_impl(v, u, d, bq8_1->ds);
|
||||
};
|
||||
};
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
template <int vdr>
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
||||
const float &d4,
|
||||
static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4,
|
||||
const sycl::half2 & ds8) {
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
@ -270,8 +320,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u,
|
||||
sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi);
|
||||
}
|
||||
|
||||
const sycl::float2 ds8f =
|
||||
ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
const sycl::float2 ds8f = ds8.convert<float, sycl::rounding_mode::automatic>();
|
||||
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y());
|
||||
|
Loading…
x
Reference in New Issue
Block a user