feat: add new sin and cos operators (ggml/919)

* ggml : add sin/cos operators

* ggml-cuda : add sin/cos operators

* ggml : add corresponding tests for sin/cos

* ggml : add backward computation for sin/cos operators

* ggml-vulkan : add sin/cos operators

* ggml-vulkan : add sin/cos shader source

* metal : add sin, cos

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
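
For context, a minimal stand-alone sketch of how the new operators are used through the public ggml API. The scratch memory size, element count, thread count, and printed output below are illustrative and are not part of this commit:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    // illustrative scratch size; enough for a few small tensors and the graph
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 1-D F32 input tensor with a few sample values
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    for (int i = 0; i < 4; ++i) {
        ggml_set_f32_1d(x, i, 0.5f*i);
    }

    // the operators added by this commit: element-wise sin/cos
    struct ggml_tensor * y = ggml_sin(ctx, x);
    struct ggml_tensor * z = ggml_cos(ctx, x);

    // build and run the forward graph on the CPU
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_build_forward_expand(gf, z);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    for (int i = 0; i < 4; ++i) {
        printf("sin(%.2f) = %.6f, cos(%.2f) = %.6f\n",
               0.5f*i, ggml_get_f32_1d(y, i),
               0.5f*i, ggml_get_f32_1d(z, i));
    }

    ggml_free(ctx);
    return 0;
}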
Author: Ronsor
Date: 2024-08-12 06:02:08 -07:00
Committed by: Georgi Gerganov
Parent: d65786ea54
Commit: 3643120690

8 changed files with 398 additions and 5 deletions


@@ -2310,7 +2310,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2760,6 +2762,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SQR",
"SQRT",
"LOG",
"SIN",
"COS",
"SUM",
"SUM_ROWS",
"MEAN",
@@ -2833,7 +2837,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2848,6 +2852,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"x^2",
"√x",
"log(x)",
"sin(x)",
"cos(x)",
"Σx",
"Σx_k",
"Σx/n",
@@ -2921,7 +2927,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4882,6 +4888,72 @@ struct ggml_tensor * ggml_log_inplace(
return ggml_log_impl(ctx, a, true);
}
// ggml_sin
static struct ggml_tensor * ggml_sin_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
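// same bookkeeping as the other element-wise unary ops (e.g. ggml_sqr, ggml_log):
// a gradient tensor is needed only when the input has a grad and the op is not in-place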
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_SIN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_sin(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_sin_impl(ctx, a, false);
}
struct ggml_tensor * ggml_sin_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_sin_impl(ctx, a, true);
}
// ggml_cos
static struct ggml_tensor * ggml_cos_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (!inplace && (a->grad)) {
is_node = true;
}
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_COS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
struct ggml_tensor * ggml_cos(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_cos_impl(ctx, a, false);
}
struct ggml_tensor * ggml_cos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_cos_impl(ctx, a, true);
}
// ggml_sum
struct ggml_tensor * ggml_sum(
@@ -10512,6 +10584,96 @@ static void ggml_compute_forward_log(
}
}
// ggml_compute_forward_sin
static void ggml_compute_forward_sin_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
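// sin/cos are scheduled as single-threaded ops (n_tasks == 1, see ggml_get_n_tasks below),
// so only thread 0 does the work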
if (params->ith != 0) {
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_sin_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_sin(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_sin_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_cos
static void ggml_compute_forward_cos_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
if (params->ith != 0) {
return;
}
GGML_ASSERT(ggml_are_same_shape(src0, dst));
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT( dst->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
ggml_vec_cos_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
(float *) ((char *) src0->data + i*(src0->nb[1])));
}
}
static void ggml_compute_forward_cos(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cos_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_sum
static void ggml_compute_forward_sum_f32(
@@ -16787,6 +16949,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_log(params, tensor);
} break;
case GGML_OP_SIN:
{
ggml_compute_forward_sin(params, tensor);
} break;
case GGML_OP_COS:
{
ggml_compute_forward_cos(params, tensor);
} break;
case GGML_OP_SUM:
{
ggml_compute_forward_sum(params, tensor);
@@ -17433,6 +17603,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
case GGML_OP_SIN:
{
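// d/dx sin(x) = cos(x): accumulate grad(y) * cos(x) into grad(x)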
if (src0->grad) {
src0->grad =
ggml_add_or_set(ctx,
src0->grad,
ggml_mul(ctx,
tensor->grad,
ggml_cos(ctx, src0)),
zero_table);
}
} break;
case GGML_OP_COS:
{
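// d/dx cos(x) = -sin(x): the minus sign is folded into ggml_sub_or_set,
// which subtracts grad(y) * sin(x) from grad(x)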
if (src0->grad) {
src0->grad =
ggml_sub_or_set(ctx,
src0->grad,
ggml_mul(ctx,
tensor->grad,
ggml_sin(ctx, src0)),
zero_table);
}
} break;
case GGML_OP_SUM:
{
if (src0->grad) {
@@ -18520,6 +18714,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: