From 4b60ff4f92bd4a767d5aff693484dc8255ec7672 Mon Sep 17 00:00:00 2001 From: Gian-Carlo Pascutto Date: Tue, 25 Feb 2025 10:27:58 +0100 Subject: [PATCH] metal : copy kernels for quant to F32/F16 conversions (llama/12017) metal: use dequantize_q templates --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal.m | 82 ++++++++++++++++++++++++++-- ggml/src/ggml-metal/ggml-metal.metal | 43 +++++++++++++++ 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 087e7f58..c550142a 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -407,6 +407,16 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SQRT, @@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); @@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } default: return false; }; @@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node( case GGML_OP_CPY: case GGML_OP_CONT: { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - id pipeline = nil; switch (src0t) { @@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node( switch (dstt) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q4_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q5_1: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_Q8_0: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break; + default: GGML_ABORT("not implemented"); }; } break; default: GGML_ABORT("not implemented"); @@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SET: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 83e7ac9f..d092a169 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl( } } +template +kernel void kernel_cpy_q_f32( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + T4x4 temp; + dequantize_func(src_data + i00/nl, i00%nl, temp); + dst_data[i00] = temp; + } +} + +typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; + +template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; + +template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; + kernel void kernel_concat( constant ggml_metal_kargs_concat & args, device const char * src0,