mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-29 07:19:51 +00:00
metal: add neg operator (llama/13029)
This commit is contained in:
parent
0287a5c51b
commit
f4ca3e2f9c
@ -481,6 +481,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||||
GGML_METAL_KERNEL_TYPE_SIN,
|
GGML_METAL_KERNEL_TYPE_SIN,
|
||||||
GGML_METAL_KERNEL_TYPE_COS,
|
GGML_METAL_KERNEL_TYPE_COS,
|
||||||
|
GGML_METAL_KERNEL_TYPE_NEG,
|
||||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||||
@ -1159,6 +1160,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||||
@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
case GGML_UNARY_OP_GELU_QUICK:
|
case GGML_UNARY_OP_GELU_QUICK:
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
case GGML_UNARY_OP_ELU:
|
case GGML_UNARY_OP_ELU:
|
||||||
|
case GGML_UNARY_OP_NEG:
|
||||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_UNARY_OP_NEG:
|
||||||
|
{
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
|
||||||
|
const int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||||
|
@ -949,6 +949,13 @@ kernel void kernel_cos(
|
|||||||
dst[tpig] = cos(src0[tpig]);
|
dst[tpig] = cos(src0[tpig]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_neg(
|
||||||
|
device const float * src0,
|
||||||
|
device float * dst,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
dst[tpig] = -src0[tpig];
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_sum_rows(
|
kernel void kernel_sum_rows(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user