metal : add mean kernel (llama/14267)

* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-19 08:05:21 +03:00
parent 203451bcba
commit b251d739ad
2 changed files with 67 additions and 14 deletions

View File

@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX, GGML_METAL_KERNEL_TYPE_ARGMAX,
@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG: case GGML_OP_LOG:
return false; // TODO: implement return false; // TODO: implement
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{ {
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; id<MTLComputePipelineState> pipeline = nil;
switch (dst->op) {
case GGML_OP_SUM_ROWS:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
break;
case GGML_OP_MEAN:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = { ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {

View File

@ -993,31 +993,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig]; dst[tpig] = -src0[tpig];
} }
template <bool norm>
kernel void kernel_sum_rows( kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0, device const float * src0,
device float * dst, device float * dst,
constant ggml_metal_kargs_sum_rows & args, threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tpig[[thread_position_in_grid]]) { uint3 tgpig[[threadgroup_position_in_grid]],
int64_t i3 = tpig.z; ushort3 tpitg[[thread_position_in_threadgroup]],
int64_t i2 = tpig.y; ushort sgitg[[simdgroup_index_in_threadgroup]],
int64_t i1 = tpig.x; ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return; return;
} }
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float row_sum = 0; float sumf = 0;
for (int64_t i0 = 0; i0 < args.ne00; i0++) { for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
row_sum += src_row[i0]; sumf += src_row[i0];
} }
dst_row[0] = row_sum; sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
} }
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T> template<typename T>
kernel void kernel_soft_max( kernel void kernel_soft_max(
device const char * src0, device const char * src0,