From b251d739ad057fa2dc84a6e2c2ec49f76237fe7e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 19 Jun 2025 08:05:21 +0300 Subject: [PATCH] metal : add mean kernel (llama/14267) * metal : add mean kernel ggml-ci * cont : dedup implementation ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++--- ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------ 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index bc93bc63..4e7f373c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -498,6 +498,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_ARGMAX, @@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); @@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_LOG: return false; // TODO: implement case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); @@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: { GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + id pipeline = nil; + switch (dst->op) { + case GGML_OP_SUM_ROWS: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + break; + case GGML_OP_MEAN: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline; + break; + default: + GGML_ABORT("fatal error"); + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00); ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, @@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&args length:sizeof(args) atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_SOFT_MAX: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5d776021..3da19879 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -993,31 +993,61 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +template kernel void kernel_sum_rows( + constant ggml_metal_kargs_sum_rows & args, device const float * src0, device float * dst, - constant ggml_metal_kargs_sum_rows & args, - uint3 tpig[[thread_position_in_grid]]) { - int64_t i3 = tpig.z; - int64_t i2 = tpig.y; - int64_t i1 = tpig.x; + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + int64_t i3 = tgpig.z; + int64_t i2 = tgpig.y; + int64_t i1 = tgpig.x; if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { return; } + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float row_sum = 0; + float sumf = 0; - for (int64_t i0 = 0; i0 < args.ne00; i0++) { - row_sum += src_row[i0]; + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + sumf += src_row[i0]; } - dst_row[0] = row_sum; + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + if (tpitg.x == 0) { + dst_row[0] = norm ? sumf / args.ne00 : sumf; + } } +typedef decltype(kernel_sum_rows) kernel_sum_rows_t; + +template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows; + template kernel void kernel_soft_max( device const char * src0,