From c52d1035de6ddaa388f2a5d4bfd0713d447a4ae8 Mon Sep 17 00:00:00 2001 From: PAB Date: Mon, 2 Dec 2024 19:27:24 +0100 Subject: [PATCH] feat: add `GGML_UNARY_OP_ARGMAX` Metal kernel (ggml/1019) * implemented argmax kernel * tpig -> tgpig * change to strides * contiguous assertions * kernel working and tested * argmax simd parallel implementation * added 2 new tests for argmax in test-backend-ops * cosmit * added 3 tests cases for perf eval * add test_argmax in make_test_cases_perf * Update test-backend-ops.cpp Co-authored-by: Diego Devesa --------- Co-authored-by: Diego Devesa --- ggml/src/ggml-metal/ggml-metal.m | 28 ++++++++++++++ ggml/src/ggml-metal/ggml-metal.metal | 57 ++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 88874b90..88931345 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -352,6 +352,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, + GGML_METAL_KERNEL_TYPE_ARGMAX, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -876,6 +877,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } @@ -1005,6 +1007,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_RMS_NORM: case GGML_OP_GROUP_NORM: return has_simdgroup_reduction; + case GGML_OP_ARGMAX: case GGML_OP_NORM: case GGML_OP_ROPE: return true; @@ -3615,6 +3618,31 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; + case GGML_OP_ARGMAX: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + + const int64_t nrows = ggml_nrows(src0); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 770a58f2..ecb60379 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1248,6 +1248,63 @@ kernel void kernel_ssm_scan_f32( } } +kernel void kernel_argmax( + device const void * x, + device int32_t * dst, + constant int64_t & ncols, + constant uint64_t & nb01, + threadgroup float * shared_maxval [[threadgroup(0)]], + threadgroup int32_t * shared_argmax [[threadgroup(1)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + + float lmax = -INFINITY; + int32_t larg = -1; + + for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + if (x_row[i00] > lmax) { + lmax = x_row[i00]; + larg = i00; + } + } + + // find the argmax value in the block + float max_val = simd_max(lmax); + int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + shared_maxval[tiisg] = -INFINITY; + shared_argmax[tiisg] = -1; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shared_maxval[sgitg] = max_val; + shared_argmax[sgitg] = arg_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = shared_maxval[tiisg]; + arg_val = shared_argmax[tiisg]; + + float max_val_reduced = simd_max(max_val); + int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); + + dst[tgpig] = arg_val_reduced; + + return; + } + + dst[tgpig] = arg_val; +} + kernel void kernel_norm( constant ggml_metal_kargs_norm & args, device const char * src0,