mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-18 02:39:47 +00:00
llama : add qwen2moe (llama/6074)
* support qwen2moe * fix-review * metal : support unary ops for nelements % 4 != 0 * metal : require contiguousness for float4 unary kernels * metal : require contiguousness for float4 unary kernels (cont) * fix-review * names : for brevity "SHARED_EXP" -> "SHEXP" * llama : reuse build_moe_ffn() * llama : add model type name --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
98c0b77e0c
commit
fdb2c87350
57
ggml-metal.m
57
ggml-metal.m
@ -42,8 +42,11 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_RELU,
|
||||
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
||||
GGML_METAL_KERNEL_TYPE_GELU,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_4,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
||||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||
GGML_METAL_KERNEL_TYPE_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
||||
@ -475,8 +478,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||
@ -1181,6 +1187,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
// we are not taking into account the strides, so for now require contiguous tensors
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
case GGML_UNARY_OP_TANH:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
||||
@ -1219,42 +1228,60 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
||||
int64_t n = ggml_nelements(dst);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (n % 4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
|
||||
n /= 4;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
||||
int64_t n = ggml_nelements(dst);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (n % 4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
|
||||
n /= 4;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
||||
int64_t n = ggml_nelements(dst);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (n % 4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
|
||||
n /= 4;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -249,6 +249,15 @@ constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
kernel void kernel_gelu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
@ -262,6 +271,15 @@ kernel void kernel_gelu(
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
@ -271,6 +289,14 @@ kernel void kernel_gelu_quick(
|
||||
}
|
||||
|
||||
kernel void kernel_silu(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
+kernel void kernel_silu_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
|
Loading…
Reference in New Issue
Block a user