From 3563473d2c5dfda9d06a67d61787f026fa1de977 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 29 May 2024 07:00:24 +0800 Subject: [PATCH] Align GEMM dispatch (llama/7566) * align GEMM dispatch --- ggml-sycl.cpp | 122 +++++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 67 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index dccfe9eb..a7344813 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3022,20 +3022,19 @@ static int g_work_group_size = 0; // typedef sycl::half ggml_fp16_t; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 //todo for hardward optimize. +#define VER_4VEC 130 //todo for hardward optimize. #define VER_GEN9 700 //todo for hardward optimize. #define VER_GEN12 1000000 //todo for hardward optimize. #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize. #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares - -//define for XMX in Intel GPU -//TODO: currently, it's not used for XMX really. -#define SYCL_USE_XMX +#if !defined(GGML_SYCL_FORCE_MMQ) + #define SYCL_USE_XMX +#endif // max batch size to use MMQ kernels when tensor cores are available -#define XMX_MAX_BATCH_SIZE 32 +#define MMQ_MAX_BATCH_SIZE 32 #if defined(_MSC_VER) @@ -15249,6 +15248,29 @@ catch (sycl::exception const &exc) { std::exit(1); } +inline bool ggml_sycl_supports_mmq(enum ggml_type type) { + // TODO: accuracy issues in MMQ + return false; +} + +bool ggml_sycl_supports_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_F16: + return true; + default: + return false; + } +} static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = @@ -15265,76 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } + // check data types and tensor shapes for custom matrix multiplication kernels: + bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + // mmvq and mmq need the __dp4a instruction which is available for gen12+ + // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); #ifdef SYCL_USE_XMX - const bool use_xmx = true; -#else - const bool use_xmx = false; -#endif + use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); +#endif // SYCL_USE_XMX - // debug helpers - //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); - //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); - //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); - //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); - //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); - //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - - if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // KQ single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n"); ggml_sycl_mul_mat_vec_p021(src0, src1, dst); - } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n"); ggml_sycl_mul_mat_vec_nc(src0, src1, dst); - } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) { + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n"); ggml_sycl_mul_mat_batched_sycl(src0, src1, dst); - } else if (src0->type == GGML_TYPE_F32) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n"); - if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) { -#ifdef GGML_SYCL_FORCE_DMMV - const bool use_mul_mat_vec_q = false; -#else - bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - use_mul_mat_vec_q = use_mul_mat_vec_q || - (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) || - (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) || - (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) || - (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M); - - -#endif // GGML_SYCL_FORCE_DMMV - - if (use_mul_mat_vec_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - } - } else { - bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type); - use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); - - if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) { - use_mul_mat_q = false; - } - - if (use_mul_mat_q) { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); - } else { - // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n"); - ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); - } - } + } else if (use_dequantize_mul_mat_vec) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); + } else if (use_mul_mat_vec_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + } else if (use_mul_mat_q) { + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true); } else { - GGML_ASSERT(false); + ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); } }