feat: add new sin and cos operators (ggml/919)

* ggml : add sin/cos operators

* ggml-cuda : add sin/cos operators

* ggml : add corresponding tests for sin/cos

* ggml : add backward computation for sin/cos operators

* ggml-vulkan : add sin/cos operators

* ggml-vulkan : add sin/cos shader source

* metal : add sin, cos

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Ronsor
2024-08-12 06:02:08 -07:00
committed by Georgi Gerganov
parent d65786ea54
commit 3643120690
8 changed files with 398 additions and 5 deletions

View File

@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
dst[i] = sqrtf(x[i]);
}
static __global__ void sin_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = sinf(x[i]);
}
static __global__ void cos_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = cosf(x[i]);
}
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

View File

@ -9,6 +9,8 @@
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);