cuda: unary ops as float + de-duplicate (ggml/1130)

This commit is contained in:
cmdr2 2025-03-03 20:51:31 +05:30 committed by Georgi Gerganov
parent 6ac8e6b2ce
commit 846e01b2c0
2 changed files with 210 additions and 645 deletions

View File

@ -1,20 +1,24 @@
#include "clamp.cuh"
static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
return fminf(fmaxf(x, min), max);
}
template <class T>
static __global__ void op_clamp(const T * x, T * dst, const T min, const T max, const int k) {
static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
}
template <class T>
static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
op_clamp<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}

View File

@ -1,447 +1,213 @@
#include "unary.cuh"
template <class T>
static __global__ void op_abs(const T * x, T * dst, const int k) {
static __device__ __forceinline__ float op_abs(float x) {
return fabsf(x);
}
static __device__ __forceinline__ float op_sgn(float x) {
return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));
}
static __device__ __forceinline__ float op_neg(float x) {
return -x;
}
static __device__ __forceinline__ float op_step(float x) {
return x > 0.0f;
}
static __device__ __forceinline__ float op_gelu(float x) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
static __device__ __forceinline__ float op_gelu_quick(float x) {
const float GELU_QUICK_COEF = -1.702f;
return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
}
static __device__ __forceinline__ float op_silu(float x) {
return x / (1.0f + expf(-x));
}
static __device__ __forceinline__ float op_tanh(float x) {
return tanhf(x);
}
static __device__ __forceinline__ float op_relu(float x) {
return fmaxf(x, 0);
}
static __device__ __forceinline__ float op_sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}
static __device__ __forceinline__ float op_hardsigmoid(float x) {
return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}
static __device__ __forceinline__ float op_hardswish(float x) {
return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
}
static __device__ __forceinline__ float op_exp(float x) {
return expf(x);
}
static __device__ __forceinline__ float op_sqr(float x) {
return x * x;
}
static __device__ __forceinline__ float op_sqrt(float x) {
return sqrtf(x);
}
static __device__ __forceinline__ float op_sin(float x) {
return sinf(x);
}
static __device__ __forceinline__ float op_cos(float x) {
return cosf(x);
}
static __device__ __forceinline__ float op_log(float x) {
return logf(x);
}
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fabsf(x[i]);
dst[i] = (T)op((float)x[i]);
}
template <class T>
static __global__ void op_sgn(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (T)(x[i] > (T)0.f ? 1.f : ((x[i] < (T)0.f ? -1.f : 0.f)));
}
template <class T>
static __global__ void op_neg(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = -x[i];
}
template <class T>
static __global__ void op_step(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] > (T)0.0f;
}
template <class T>
static __global__ void op_gelu(const T * x, T * dst, const int k) {
const T GELU_COEF_A = 0.044715f;
const T SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
T xi = x[i];
dst[i] = (T)0.5f*xi*((T)1.0f + (T)tanhf(SQRT_2_OVER_PI*xi*((T)1.0f + GELU_COEF_A*xi*xi)));
}
template <class T>
static __global__ void op_gelu_quick(const T * x, T * dst, int k) {
const T GELU_QUICK_COEF = -1.702f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * ((T)1.0f / ((T)1.0f + (T)expf(GELU_QUICK_COEF * x[i])));
}
template <class T>
static __global__ void op_silu(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] / ((T)1.0f + (T)expf(-x[i]));
}
template <class T>
static __global__ void op_silu_back(
const T * grad, const T * xf, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const T xfi = xf[i];
const T s = (T)1.0f / ((T)1.0f + (T)expf(-xfi));
dst[i] = grad[i] * s * ((T)1.0f + xfi * ((T)1.0f - s));
}
template <class T>
static __global__ void op_tanh(const T * x, T * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = tanhf(x[i]);
}
template <class T>
static __global__ void op_relu(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0);
}
template <class T>
static __global__ void op_sigmoid(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (T)1.0f / ((T)1.0f + (T)expf(-x[i]));
}
template <class T>
static __global__ void op_hardsigmoid(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
}
template <class T>
static __global__ void op_hardswish(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * (T)fminf(1.0f, fmaxf(0.0f, (x[i] + (T)3.0f) / (T)6.0f));
}
template <class T>
static __global__ void op_exp(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = expf(x[i]);
}
template <class T>
static __global__ void op_leaky_relu(const T * x, T * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (T)fmaxf(x[i], 0) + (T)fminf(x[i], 0.0f) * (T)negative_slope;
}
template <class T>
static __global__ void op_sqr(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * x[i];
}
template <class T>
static __global__ void op_sqrt(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = sqrtf(x[i]);
}
template <class T>
static __global__ void op_sin(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = sinf(x[i]);
}
template <class T>
static __global__ void op_cos(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = cosf(x[i]);
}
template <class T>
static __global__ void op_log(const T * x, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = logf(x[i]);
}
template <class T>
static void abs_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
op_abs<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <float (*op)(float)>
void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_abs>(ctx, dst);
}
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sgn>(ctx, dst);
}
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_neg>(ctx, dst);
}
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_step>(ctx, dst);
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu>(ctx, dst);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
}
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_silu>(ctx, dst);
}
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_tanh>(ctx, dst);
}
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_relu>(ctx, dst);
}
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
}
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);
}
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_hardswish>(ctx, dst);
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_exp>(ctx, dst);
}
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sqr>(ctx, dst);
}
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sqrt>(ctx, dst);
}
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_sin>(ctx, dst);
}
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_cos>(ctx, dst);
}
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_log>(ctx, dst);
}
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
const float s = 1.0f / (1.0f + expf(-x));
return grad * s * (1.0f + x * (1.0f - s));
}
template <class T>
static void sgn_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
op_sgn<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
template <class T>
static void neg_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
op_neg<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
if (i >= k) {
return;
}
template <class T>
static void step_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
op_step<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void gelu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
op_gelu<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void gelu_quick_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
op_gelu_quick<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void silu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
op_silu<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
}
template <class T>
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
op_silu_back<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
template <class T>
static void tanh_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
op_tanh<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void relu_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
op_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void sigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
op_sigmoid<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void hardsigmoid_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
op_hardsigmoid<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void hardswish_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
op_hardswish<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void exp_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
op_exp<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
op_leaky_relu<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}
template <class T>
static void sqr_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
op_sqr<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void sqrt_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
op_sqrt<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void sin_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
op_sin<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void cos_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
op_cos<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
template <class T>
static void log_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
op_log<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
abs_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
abs_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
sgn_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
sgn_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
neg_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
neg_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
step_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
step_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
gelu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
gelu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
silu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
silu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -467,137 +233,27 @@ void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
}
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
/* leaky relu */
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
gelu_quick_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
gelu_quick_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {
return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope;
}
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
template <class T>
static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
tanh_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
tanh_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
if (i >= k) {
return;
}
dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
}
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
sigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
sigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
hardsigmoid_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
hardsigmoid_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
hardswish_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
hardswish_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
exp_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
exp_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
template <class T>
static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -621,98 +277,3 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
}
}
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
sqr_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
sqr_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
sqrt_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
sqrt_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
sin_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
sin_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
cos_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
cos_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
void * dst_d = dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
if (src0->type == GGML_TYPE_F16) {
log_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
} else {
log_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
}
}