mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-19 03:06:26 +00:00
feat: add GGML_UNARY_OP_ARGMAX
Metal kernel (ggml/1019)
* implemented argmax kernel * tpig -> tgpig * change to strides * contiguous assertions * kernel working and tested * argmax simd parallel implementation * added 2 new tests for argmax in test-backend-ops * cosmit * added 3 tests cases for perf eval * add test_argmax in make_test_cases_perf * Update test-backend-ops.cpp Co-authored-by: Diego Devesa <slarengh@gmail.com> --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
5773a14980
commit
c52d1035de
@ -352,6 +352,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
};
|
||||
@ -876,6 +877,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
}
|
||||
@ -1005,6 +1007,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_ARGMAX:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ROPE:
|
||||
return true;
|
||||
@ -3615,6 +3618,31 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ARGMAX:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||
|
@ -1248,6 +1248,63 @@ kernel void kernel_ssm_scan_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_argmax(
|
||||
device const void * x,
|
||||
device int32_t * dst,
|
||||
constant int64_t & ncols,
|
||||
constant uint64_t & nb01,
|
||||
threadgroup float * shared_maxval [[threadgroup(0)]],
|
||||
threadgroup int32_t * shared_argmax [[threadgroup(1)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
|
||||
|
||||
float lmax = -INFINITY;
|
||||
int32_t larg = -1;
|
||||
|
||||
for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
|
||||
if (x_row[i00] > lmax) {
|
||||
lmax = x_row[i00];
|
||||
larg = i00;
|
||||
}
|
||||
}
|
||||
|
||||
// find the argmax value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
shared_maxval[tiisg] = -INFINITY;
|
||||
shared_argmax[tiisg] = -1;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_maxval[sgitg] = max_val;
|
||||
shared_argmax[sgitg] = arg_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = shared_maxval[tiisg];
|
||||
arg_val = shared_argmax[tiisg];
|
||||
|
||||
float max_val_reduced = simd_max(max_val);
|
||||
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
|
||||
|
||||
dst[tgpig] = arg_val_reduced;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
dst[tgpig] = arg_val;
|
||||
}
|
||||
|
||||
kernel void kernel_norm(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
|
Loading…
Reference in New Issue
Block a user