feat: add new sin and cos operators (ggml/919)

* ggml : add sin/cos operators

* ggml-cuda : add sin/cos operators

* ggml : add corresponding tests for sin/cos

* ggml : add backward computation for sin/cos operators

* ggml-vulkan : add sin/cos operators

* ggml-vulkan : add sin/cos shader source

* metal : add sin, cos

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Ronsor 2024-08-12 06:02:08 -07:00 committed by Georgi Gerganov
parent d65786ea54
commit 3643120690
8 changed files with 398 additions and 5 deletions

View File

@ -451,6 +451,8 @@ extern "C" {
GGML_OP_SQR, GGML_OP_SQR,
GGML_OP_SQRT, GGML_OP_SQRT,
GGML_OP_LOG, GGML_OP_LOG,
GGML_OP_SIN,
GGML_OP_COS,
GGML_OP_SUM, GGML_OP_SUM,
GGML_OP_SUM_ROWS, GGML_OP_SUM_ROWS,
GGML_OP_MEAN, GGML_OP_MEAN,
@ -967,6 +969,22 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sin(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_sin_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cos(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_cos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return scalar // return scalar
GGML_API struct ggml_tensor * ggml_sum( GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx, struct ggml_context * ctx,

View File

@ -2267,6 +2267,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SQRT: case GGML_OP_SQRT:
ggml_cuda_op_sqrt(ctx, dst); ggml_cuda_op_sqrt(ctx, dst);
break; break;
case GGML_OP_SIN:
ggml_cuda_op_sin(ctx, dst);
break;
case GGML_OP_COS:
ggml_cuda_op_cos(ctx, dst);
break;
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
ggml_cuda_op_clamp(ctx, dst); ggml_cuda_op_clamp(ctx, dst);
break; break;
@ -2859,6 +2865,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SQRT: case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_CONT: case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:

View File

@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
dst[i] = sqrtf(x[i]); dst[i] = sqrtf(x[i]);
} }
static __global__ void sin_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = sinf(x[i]);
}
static __global__ void cos_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = cosf(x[i]);
}
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k); gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k); sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} }
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data; const float * src0_d = (const float *)src0->data;
@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
} }
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

View File

@ -9,6 +9,8 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256 #define CUDA_SQRT_BLOCK_SIZE 256
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -205,6 +205,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_COUNT GGML_METAL_KERNEL_TYPE_COUNT
@ -665,6 +667,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
} }
@ -771,9 +775,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SUM_ROWS:
return true; return true;
case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
@ -1409,6 +1416,34 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SIN:
{
GGML_ASSERT(ggml_is_contiguous(src0));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_COS:
{
GGML_ASSERT(ggml_is_contiguous(src0));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:

View File

@ -358,6 +358,20 @@ kernel void kernel_sqr(
dst[tpig] = src0[tpig] * src0[tpig]; dst[tpig] = src0[tpig] * src0[tpig];
} }
kernel void kernel_sin(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = sin(src0[tpig]);
}
kernel void kernel_cos(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = cos(src0[tpig]);
}
kernel void kernel_sum_rows( kernel void kernel_sum_rows(
device const float * src0, device const float * src0,
device float * dst, device float * dst,

View File

@ -184,6 +184,8 @@ struct vk_device_struct {
vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
@ -1654,6 +1656,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -3972,6 +3976,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_sqr_f32; return ctx->device->pipeline_sqr_f32;
} }
return nullptr; return nullptr;
case GGML_OP_SIN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_sin_f32;
}
return nullptr;
case GGML_OP_COS:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32; return ctx->device->pipeline_clamp_f32;
@ -4124,6 +4138,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
return true; return true;
@ -4335,6 +4351,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_CPY: case GGML_OP_CPY:
@ -4576,6 +4594,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
}); });
} }
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
});
}
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f,
});
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params; float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src0_type_size = ggml_type_size(src0->type);
@ -5481,6 +5525,8 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
case GGML_OP_ADD: case GGML_OP_ADD:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_CPY: case GGML_OP_CPY:
@ -5761,6 +5807,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_CPY: case GGML_OP_CPY:
@ -5832,6 +5880,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_SQR: case GGML_OP_SQR:
ggml_vk_sqr(ctx, compute_ctx, src0, node); ggml_vk_sqr(ctx, compute_ctx, src0, node);
break;
case GGML_OP_SIN:
ggml_vk_sin(ctx, compute_ctx, src0, node);
break;
case GGML_OP_COS:
ggml_vk_cos(ctx, compute_ctx, src0, node);
break; break;
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node); ggml_vk_clamp(ctx, compute_ctx, src0, node);
@ -5943,6 +5999,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_CPY: case GGML_OP_CPY:
@ -6658,6 +6716,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
case GGML_OP_PAD: case GGML_OP_PAD:
case GGML_OP_CONT: case GGML_OP_CONT:

View File

@ -2310,7 +2310,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@ -2760,6 +2762,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SQR", "SQR",
"SQRT", "SQRT",
"LOG", "LOG",
"SIN",
"COS",
"SUM", "SUM",
"SUM_ROWS", "SUM_ROWS",
"MEAN", "MEAN",
@ -2833,7 +2837,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK", "CROSS_ENTROPY_LOSS_BACK",
}; };
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -2848,6 +2852,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"x^2", "x^2",
"√x", "√x",
"log(x)", "log(x)",
"sin(x)",
"cos(x)",
"Σx", "Σx",
"Σx_k", "Σx_k",
"Σx/n", "Σx/n",
@ -2921,7 +2927,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)", "cross_entropy_loss_back(x,y)",
}; };
static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4882,6 +4888,72 @@ struct ggml_tensor * ggml_log_inplace(
return ggml_log_impl(ctx, a, true); return ggml_log_impl(ctx, a, true);
} }
// ggml_sin
static struct ggml_tensor * ggml_sin_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_SIN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_sin(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_sin_impl(ctx, a, false);
}
struct ggml_tensor * ggml_sin_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_sin_impl(ctx, a, true);
}
// ggml_cos
static struct ggml_tensor * ggml_cos_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_COS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_cos(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_cos_impl(ctx, a, false);
}
struct ggml_tensor * ggml_cos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_cos_impl(ctx, a, true);
}
// ggml_sum // ggml_sum
struct ggml_tensor * ggml_sum( struct ggml_tensor * ggml_sum(
@ -10512,6 +10584,96 @@ static void ggml_compute_forward_log(
} }
} }
// ggml_compute_forward_sin
static void ggml_compute_forward_sin_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_sin_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_sin(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_sin_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_cos
static void ggml_compute_forward_cos_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_cos_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_cos(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cos_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_sum // ggml_compute_forward_sum
static void ggml_compute_forward_sum_f32( static void ggml_compute_forward_sum_f32(
@ -16787,6 +16949,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_log(params, tensor); ggml_compute_forward_log(params, tensor);
} break; } break;
case GGML_OP_SIN:
{
ggml_compute_forward_sin(params, tensor);
} break;
case GGML_OP_COS:
{
ggml_compute_forward_cos(params, tensor);
} break;
case GGML_OP_SUM: case GGML_OP_SUM:
{ {
ggml_compute_forward_sum(params, tensor); ggml_compute_forward_sum(params, tensor);
@ -17433,6 +17603,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table); zero_table);
} }
} break; } break;
case GGML_OP_SIN:
{
if (src0->grad) {
src0->grad =
ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
tensor->grad,
ggml_cos(ctx, src0)),
zero_table);
}
} break;
case GGML_OP_COS:
{
if (src0->grad) {
src0->grad =
ggml_sub_or_set(ctx,
src0->grad,
ggml_mul(ctx,
tensor->grad,
ggml_sin(ctx, src0)),
zero_table);
}
} break;
case GGML_OP_SUM: case GGML_OP_SUM:
{ {
if (src0->grad) { if (src0->grad) {
@ -18520,6 +18714,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_SQRT: case GGML_OP_SQRT:
case GGML_OP_LOG: case GGML_OP_LOG:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_SUM: case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN: