CUDA: backwards pass for misc. ops, add tests (llama/11257)
* CUDA: backwards pass for misc. ops, add tests
* remove restrict from pointers
parent db6383094c
commit de49024e49
@@ -1384,16 +1384,20 @@ extern "C" {
float scale,
float max_bias);

GGML_API struct ggml_tensor * ggml_soft_max_back(
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float scale,
float max_bias);

// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float scale,
float max_bias);

// rotary position embedding
// if (mode & 1) - skip n_past elements (NOT SUPPORTED)
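For context, a minimal usage sketch of the renamed API (not part of the diff): the tensor names, the scale value, and the assumption that `a` is the gradient of the softmax output while `b` is the softmax output itself (matching the grad-first convention of the other *_back ops in this commit) are illustrative only.

    // hypothetical call site; grad_probs and probs are placeholder tensors
    struct ggml_tensor * grad_logits = ggml_soft_max_ext_back(
            ctx, grad_probs, probs, /*scale =*/ 1.0f, /*max_bias =*/ 0.0f);

Note that both the CPU and CUDA backends in this commit only accept max_bias == 0.0f for the backward op.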
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
return true;
}

// ops that return true for this function must not use restrict pointers for their backend implementations
static bool ggml_op_can_inplace(enum ggml_op op) {
switch (op) {
case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_LOG:
case GGML_OP_UNARY:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
return true;

default:
@@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];
const struct ggml_tensor * grad = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];

assert(ggml_is_contiguous_1(grad));
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(src1));
assert(ggml_is_contiguous_1(dst));
assert(ggml_are_same_shape(src0, dst));
assert(ggml_are_same_shape(src0, grad));
assert(ggml_are_same_shape(src1, dst));
assert(ggml_are_same_shape(src1, grad));

const int ith = params->ith;
const int nth = params->nth;

const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
const int nc = src1->ne[0];
const int nr = ggml_nrows(src1);

// rows per thread
const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_silu_backward_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])),
(float *) ((char *) src1->data + i1*(src1->nb[1])),
(float *) ((char *) grad->data + i1*(grad->nb[1])));

#ifndef NDEBUG
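For reference, assuming ggml_vec_silu_backward_f32 applies the usual SiLU derivative, the per-element gradient written into dst is dx_i = dy_i \cdot \sigma(x_i)\,(1 + x_i\,(1 - \sigma(x_i))), where \sigma is the logistic sigmoid, x comes from src1 (the forward-pass input) and dy from grad; this is why the op needs both tensors and why the pointer reordering above matters.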
@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32(
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
GGML_ASSERT(eps > 0.0f);
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32(
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
GGML_ASSERT(eps > 0.0f);
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
||||
const int64_t i12 = i02;
|
||||
const int64_t i13 = i03;
|
||||
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
|
||||
ggml_float sum_xx = 0.0;
|
||||
ggml_float sum_xdz = 0.0;
|
||||
@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
||||
{
|
||||
// z = rms_norm(x)
|
||||
//
|
||||
// rms_norm(src0) =
|
||||
// rms_norm(src1) =
|
||||
// scale(
|
||||
// src0,
|
||||
// src1,
|
||||
// div(
|
||||
// 1,
|
||||
// sqrt(
|
||||
@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
||||
// scale(
|
||||
// sum(
|
||||
// sqr(
|
||||
// src0)),
|
||||
// src1)),
|
||||
// (1.0/N)),
|
||||
// eps))));
|
||||
|
||||
// postorder:
|
||||
// ## op args grad
|
||||
// 00 param src0 grad[#00]
|
||||
// 00 param src1 grad[#00]
|
||||
// 01 const 1
|
||||
// 02 sqr (#00) grad[#02]
|
||||
// 03 sum (#02) grad[#03]
|
||||
@@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);

// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
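In closed form, the gradient assembled here (and by the new rms_norm_back_f32 CUDA kernel later in this diff) is, per row of length N with forward input x and incoming gradient dz:

    dx = \frac{dz}{\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \varepsilon}} - x \cdot \frac{\sum_i x_i\,dz_i}{\left(\sum_i x_i^2 + N\varepsilon\right)\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \varepsilon}}

which is exactly dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) with sum_eps = \sum_i x_i^2 + N\varepsilon and mean_eps = \tfrac{1}{N}\sum_i x_i^2 + \varepsilon.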
@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne10);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
GGML_ASSERT(ne0 == ne00);
|
||||
GGML_ASSERT(ne1 == ne10);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
GGML_ASSERT(ne2 % ne02 == 0);
|
||||
GGML_ASSERT(ne3 % ne03 == 0);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
// dps == dst per src0, used for group query attention
|
||||
const int64_t dps2 = ne2 / ne02;
|
||||
const int64_t dps3 = ne3 / ne03;
|
||||
|
||||
for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
|
||||
const int64_t bir1 = MIN(bir + blck_1, ir1);
|
||||
for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
|
||||
@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
|
||||
const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
|
||||
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i03 = i3;
|
||||
const int64_t i02 = i2 / dps2;
|
||||
const int64_t i03 = i3 / dps3;
|
||||
|
||||
//const int64_t i10 = i1;
|
||||
const int64_t i12 = i2;
|
||||
@@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max(
}


// ggml_compute_forward_soft_max_back
// ggml_compute_forward_soft_max_ext_back

static void ggml_compute_forward_soft_max_back_f32(
static void ggml_compute_forward_soft_max_ext_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

@@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src1, dst));

float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

GGML_ASSERT(max_bias == 0.0f);

// TODO: handle transposed/permuted matrices

const int ith = params->ith;
@@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32(

// linear runtime, no additional memory
float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
ggml_vec_scale_f32(nc, dx, scale);

#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {
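The five vector ops above are the softmax Jacobian-vector product extended by the new scale parameter: with y the forward softmax output and dy the incoming gradient, each row is dx = \mathrm{scale} \cdot y \odot (dy - \langle y, dy \rangle), where \langle y, dy \rangle is the row-wise dot product computed into dot_y_dy.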
@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32(
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_soft_max_back(
|
||||
static void ggml_compute_forward_soft_max_ext_back(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_soft_max_back_f32(params, dst);
|
||||
ggml_compute_forward_soft_max_ext_back_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32(
|
||||
const int64_t IH = is_2D ? ne1 : 1;
|
||||
const int64_t IW = ne0;
|
||||
|
||||
const int64_t KH = is_2D ? ne01 : 1;
|
||||
const int64_t KW = ne00;
|
||||
const int64_t KH = is_2D ? ne11 : 1;
|
||||
const int64_t KW = ne10;
|
||||
|
||||
const int64_t OH = is_2D ? ne12 : 1;
|
||||
const int64_t OW = ne11;
|
||||
const int64_t OH = is_2D ? ne02 : 1;
|
||||
const int64_t OW = ne01;
|
||||
|
||||
int ofs0 = is_2D ? nb3 : nb2;
|
||||
int ofs1 = is_2D ? nb2 : nb1;
|
||||
@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32(
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * const src_data = (const float *) src1->data
|
||||
const float * const grad_in = (const float *) src0->data
|
||||
+ (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
|
||||
grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
|
||||
grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
|
||||
}
|
||||
}
|
||||
float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
|
||||
@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * opt0 = dst->src[2];
|
||||
const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
|
||||
const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
|
||||
const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(opt0));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0f));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1f));
|
||||
GGML_ASSERT(ggml_is_contiguous(grad));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
|
||||
|
||||
const int64_t ith = params->ith;
|
||||
const int64_t nth = params->nth;
|
||||
|
||||
// TODO: handle transposed/permuted matrices
|
||||
const int64_t nc = src0->ne[0];
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
const int64_t nc = src0f->ne[0];
|
||||
const int64_t nr = ggml_nrows(src0f);
|
||||
|
||||
// rows per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
|
||||
const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
|
||||
|
||||
for (int64_t i1 = ir0; i1 < ir1; i1++) {
|
||||
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
|
||||
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
|
||||
const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int64_t i = 0; i < nc; ++i) {
|
||||
@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
||||
// soft_max
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(nc, &max, s0);
|
||||
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
|
||||
const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
|
||||
assert(sum > 0.0);
|
||||
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
|
||||
|
||||
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
|
||||
// grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
|
||||
ggml_vec_sub_f32(nc, ds0, ds0, s1);
|
||||
ggml_vec_scale_f32(nc, ds0, d_by_nr);
|
||||
|
||||
@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
{
|
||||
ggml_compute_forward_soft_max_back(params, tensor);
|
||||
ggml_compute_forward_soft_max_ext_back(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
|
@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
|
||||
case GGML_OP_MUL_MAT:
|
||||
return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
|
||||
case GGML_OP_SOFT_MAX_BACK: {
|
||||
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
||||
|
||||
return max_bias == 0.0f;
|
||||
}
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
case GGML_OP_OUT_PROD:
|
||||
|
@@ -5,95 +5,89 @@
#include <cmath>
#include <cstdint>

static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
template <bool use_shared>
static __global__ void cross_entropy_loss_f32(
const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
extern __shared__ float tmp[];

const int ne_tmp = WARP_SIZE*nclasses;

extern __shared__ float tmp_all[];
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;

// Each warp first loads ne_tmp logits/labels into shared memory:
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
const int ig = i0*nclasses + i; // ig == i global

tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
}

// Each thread in the warp then calculates the cross entropy loss for a single row.
// TODO: pad in order to avoid shared memory bank conflicts.
logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;

// Find maximum for softmax:
float max = -INFINITY;
for (int i = 0; i < nclasses; ++i) {
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
float max_logit = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[i];
max_logit = fmaxf(max_logit, val);

if (use_shared) {
tmp[i] = val;
}
}
max_logit = warp_reduce_max(max_logit);

// Calculate log(softmax(logits)) which is just logits - max:
float sum = 0.0f;
for (int i = 0; i < nclasses; ++i) {
float val = tmp_logits[lane_id*nclasses + i] - max;
sum += expf(val);
tmp_logits[lane_id*nclasses + i] = val;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
sum += expf(logit_i - max_logit);
}
sum = warp_reduce_sum(sum);
sum = logf(sum);

// log(exp(logits - max) / sum) = (logits - max) - log(sum)
float loss = 0.0f;
for (int i = 0; i < nclasses; ++i) {
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float logit_i = use_shared ? tmp[i] : logits[i];
loss += (logit_i - max_logit - sum) * labels[i];
}
loss = -warp_reduce_sum(loss) / (float)k;

__syncthreads();

if (lane_id == 0) {
tmp_all[warp_id] = loss;
}

__syncthreads();

if (warp_id != 0) {
return;
}

loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
loss = warp_reduce_sum(loss);

if (lane_id != 0) {
if (threadIdx.x != 0) {
return;
}

dst[blockIdx.x] = loss;
}

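Written out, the value this kernel produces (one row j per block, combined across blocks by sum_f32_cuda afterwards) is the mean cross entropy

    L = -\frac{1}{k} \sum_{j=1}^{k} \sum_{i} t_{ji} \left( z_{ji} - \max_{i'} z_{ji'} - \log \sum_{i'} e^{\,z_{ji'} - \max_{i'} z_{ji'}} \right)

with z the logits and t the labels, i.e. -\tfrac{1}{k}\sum_j \langle t_j, \log \mathrm{softmax}(z_j) \rangle.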
static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
template <bool use_shared>
static __global__ void cross_entropy_loss_back_f32(
const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
float * __restrict__ dst, const int nclasses) {
extern __shared__ float tmp[];

logits += int64_t(blockIdx.x)*nclasses;
labels += int64_t(blockIdx.x)*nclasses;
dst += int64_t(blockIdx.x)*nclasses;

float maxval = -INFINITY;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = logits[blockIdx.x*nclasses + i];
const float val = logits[i];
maxval = fmaxf(maxval, val);
tmp[i] = val;

if (use_shared) {
tmp[i] = val;
}
}
maxval = warp_reduce_max(maxval);

float sum = 0.0f;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
const float val = expf(tmp[i] - maxval);
const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
sum += val;
tmp[i] = val;

if (use_shared) {
tmp[i] = val;
} else {
dst[i] = val;
}
}
sum = warp_reduce_sum(sum);
const float sm_scale = 1.0f/sum;

const float d_by_nrows = *loss/gridDim.x;
const float d_by_nrows = *grad/gridDim.x;
for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
const float val = use_shared ? tmp[i] : dst[i];
dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
}
}

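This matches the CPU comment earlier in the diff: with p_j = \mathrm{softmax}(z_j) per row, t the labels, d the scalar incoming gradient (*grad) and gridDim.x the number of rows, the kernel writes \partial L / \partial z_{ji} = (p_{ji} - t_{ji}) \cdot d / \mathrm{gridDim.x}.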
@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
|
||||
const dim3 blocks_dim(WARP_SIZE, 1, 1);
|
||||
const dim3 blocks_num(nrows, 1, 1);
|
||||
const size_t nbytes_shared = ne00*sizeof(float);
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
} else {
|
||||
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
}
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * opt0 = dst->src[2];
|
||||
const ggml_tensor * grad = dst->src[0];
|
||||
const ggml_tensor * src0f = dst->src[1];
|
||||
const ggml_tensor * src1f = dst->src[2];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(opt0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1f->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( grad->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(opt0));
|
||||
GGML_ASSERT(ggml_is_scalar(grad));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0f));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1f));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0f, dst));
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t ne00 = src0f->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0f);
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
const float * opt0_d = (const float *) opt0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
const float * grad_d = (const float *) grad->data;
|
||||
const float * src0f_d = (const float *) src0f->data;
|
||||
const float * src1f_d = (const float *) src1f->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const dim3 blocks_dim(WARP_SIZE, 1, 1);
|
||||
const dim3 blocks_num(nrows, 1, 1);
|
||||
const int shmem = ne00*sizeof(float);
|
||||
const size_t nbytes_shared = ne00*sizeof(float);
|
||||
|
||||
cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
if (!shared_memory_limit_raised[id]) {
|
||||
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
|
||||
shared_memory_limit_raised[id] = true;
|
||||
}
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
} else {
|
||||
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
|
||||
}
|
||||
}
|
||||
|
@ -3,15 +3,15 @@
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void k_get_rows(
|
||||
const void * src0, const int32_t * src1, dst_t * dst,
|
||||
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
||||
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
||||
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
||||
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
||||
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
|
||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
|
||||
@ -22,10 +22,10 @@ static __global__ void k_get_rows(
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||
|
||||
const int ib = i00/qk; // block index
|
||||
const int iqs = (i00%qk)/qr; // quant index
|
||||
const int ib = i00/qk; // block index
|
||||
const int iqs = (i00%qk)/qr; // quant index
|
||||
const int iybs = i00 - i00%qk; // dst block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
@ -39,15 +39,15 @@ static __global__ void k_get_rows(
|
||||
|
||||
template<typename src0_t, typename dst_t>
|
||||
static __global__ void k_get_rows_float(
|
||||
const src0_t * src0, const int32_t * src1, dst_t * dst,
|
||||
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
||||
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
||||
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
||||
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
||||
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
|
||||
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
|
||||
@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
||||
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
||||
|
||||
dst_row[i00] = src0_row[i00];
|
||||
}
|
||||
|
||||
template<typename grad_t, typename dst_t>
static __global__ void k_get_rows_back_float(
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
const int col = blockIdx.x*blockDim.x + threadIdx.x;

if (col >= ncols) {
return;
}

const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;

float sum = 0.0f;

for (int64_t i = 0; i < nrows_grad; ++i) {
if (rows[i] != dst_row) {
continue;
}
sum += grad[i*ncols + col];
}

dst[dst_row*ncols + col] = sum;
}

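k_get_rows_back_float is the adjoint of get_rows, a scatter-add over the selected row indices: for every destination row r and column c it computes dst[r, c] = \sum_{i : \mathrm{rows}[i] = r} \mathrm{grad}[i, c]; each thread owns one (dst_row, col) pair and loops over all nrows_grad lookups.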
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
static void get_rows_cuda(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
|
||||
template<typename src0_t>
|
||||
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
static void get_rows_cuda_float(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
|
||||
//const size_t s13 = nb13 / ggml_element_size(src1);
|
||||
|
||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
src0_dd, src1_dd, dst_dd,
|
||||
ne00, /*ne01, ne02, ne03,*/
|
||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||
/* s0,*/ s1, s2, s3,
|
||||
/* nb00,*/ nb01, nb02, nb03,
|
||||
s10, s11, s12/*, s13*/);
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
}
|
||||
@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
|
||||
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
const void * src0_d = (const void *) src0->data;
|
||||
const int32_t * src1_d = (const int32_t *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
||||
|
||||
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
||||
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F16:
|
||||
get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_d, dst_d, stream);
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
|
||||
const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const int32_t * src1_d = (const int32_t *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
GGML_ASSERT(ne02*ne03 == 1);
|
||||
GGML_ASSERT(ne12*ne13 == 1);
|
||||
GGML_ASSERT(ne2*ne3 == 1);
|
||||
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne1, 1);
|
||||
|
||||
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
|
||||
}
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_GET_ROWS_BLOCK_SIZE 256
|
||||
#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_cuda_op_get_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_cuda_op_get_rows_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
ggml_cuda_op_leaky_relu(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
ggml_cuda_op_silu_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
ggml_cuda_op_rms_norm(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
ggml_cuda_op_rms_norm_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
|
||||
@ -2138,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SOFT_MAX:
|
||||
ggml_cuda_op_soft_max(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
ggml_cuda_op_soft_max_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
ggml_cuda_op_rope(ctx, dst);
|
||||
break;
|
||||
@ -2912,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
@ -2928,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
{
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@ -3001,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
@ -3027,6 +3047,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX_BACK: {
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
||||
return max_bias == 0.0f;
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK: {
|
||||
const size_t ts = ggml_type_size(op->src[0]->type);
|
||||
|
@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float2 mean_var = make_float2(0.f, 0.f);
|
||||
x += int64_t(row)*ncols;
|
||||
dst += int64_t(row)*ncols;
|
||||
|
||||
float2 mean_var = make_float2(0.0f, 0.0f);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
const float xi = x[col];
|
||||
mean_var.x += xi;
|
||||
mean_var.y += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float2 s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = mean_var;
|
||||
}
|
||||
@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
|
||||
const float inv_std = rsqrtf(var + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
|
||||
dst[col] = (x[col] - mean) * inv_std;
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,14 +44,8 @@ template <int block_size>
|
||||
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
|
||||
// blockIdx.x: num_groups idx
|
||||
// threadIdx.x: block_size idx
|
||||
int start = blockIdx.x * group_size;
|
||||
int end = start + group_size;
|
||||
|
||||
start += threadIdx.x;
|
||||
|
||||
if (end >= ne_elements) {
|
||||
end = ne_elements;
|
||||
}
|
||||
const int start = blockIdx.x*group_size + threadIdx.x;
|
||||
const int end = min(blockIdx.x*group_size + group_size, ne_elements);
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
}
|
||||
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
float mean = tmp / group_size;
|
||||
const float mean = tmp / group_size;
|
||||
tmp = 0.0f;
|
||||
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
float xi = x[j] - mean;
|
||||
const float xi = x[j] - mean;
|
||||
dst[j] = xi;
|
||||
tmp += xi * xi;
|
||||
}
|
||||
@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
float variance = tmp / group_size;
|
||||
float scale = rsqrtf(variance + eps);
|
||||
const float variance = tmp / group_size;
|
||||
const float scale = rsqrtf(variance + eps);
|
||||
for (int j = start; j < end; j += block_size) {
|
||||
dst[j] *= scale;
|
||||
}
|
||||
@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
x += int64_t(row)*ncols;
|
||||
dst += int64_t(row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_back_f32(
|
||||
const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
grad += int64_t(row)*ncols;
|
||||
xf += int64_t(row)*ncols;
|
||||
dst += int64_t(row)*ncols;
|
||||
|
||||
float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
|
||||
float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xfi = xf[col];
|
||||
sum_xx += xfi * xfi;
|
||||
sum_xg += xfi * grad[col];
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
sum_xx = warp_reduce_sum(sum_xx);
|
||||
sum_xg = warp_reduce_sum(sum_xg);
|
||||
if constexpr (block_size > WARP_SIZE) {
|
||||
static_assert(block_size == 1024, "unexpected block_size");
|
||||
__shared__ float s_sum_xx[32];
|
||||
__shared__ float s_sum_xg[32];
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum_xx[warp_id] = sum_xx;
|
||||
s_sum_xg[warp_id] = sum_xg;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
sum_xx = s_sum_xx[lane_id];
|
||||
sum_xx = warp_reduce_sum(sum_xx);
|
||||
|
||||
sum_xg = s_sum_xg[lane_id];
|
||||
sum_xg = warp_reduce_sum(sum_xg);
|
||||
}
|
||||
|
||||
const float mean_eps = sum_xx / ncols + eps;
|
||||
const float sum_eps = sum_xx + ncols*eps;
|
||||
|
||||
const float scale_grad = rsqrtf(mean_eps);
|
||||
const float scale_x = -scale_grad * sum_xg/sum_eps;
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale_grad*grad[col] + scale_x*xf[col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
|
||||
}
|
||||
}
|
||||
|
||||
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
|
||||
static void group_norm_f32_cuda(
|
||||
const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
|
||||
if (group_size < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
|
||||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
|
||||
@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * grad = dst->src[0]; // gradients
|
||||
const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
|
||||
|
||||
const float * grad_d = (const float *) grad->data;
|
||||
const float * src0f_d = (const float *) src0f->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(grad));
|
||||
|
||||
GGML_ASSERT( grad->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0f->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0f->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0f);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||
}
|
||||
|
@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(dst));

GGML_ASSERT(ne01 == ne11);
GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne10);

GGML_ASSERT(ne2 == src0->ne[2]);
GGML_ASSERT(ne2 % src0->ne[2] == 0);
GGML_ASSERT(ne3 % src0->ne[3] == 0);

GGML_ASSERT(ne2 == src1->ne[2]);
GGML_ASSERT(ne3 == src0->ne[3]);
GGML_ASSERT(ne3 == src1->ne[3]);

const float * src0_d = (const float *) src0->data;
@@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float alpha = 1.0f;
const float beta = 0.0f;

GGML_ASSERT(ne2 == 1);
GGML_ASSERT(ne3 == 1);
CUBLAS_CHECK(cublasSetStream(handle, stream));

const bool src1_T = ggml_is_transposed(src1);
@@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));

CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d, ne00,
src1_d, ldb,
&beta, dst_d, ne0));
// data strides in dimensions 2/3
const size_t s02 = nb02 / sizeof(float);
const size_t s03 = nb03 / sizeof(float);
const size_t s12 = nb12 / sizeof(float);
const size_t s13 = nb13 / sizeof(float);
const size_t s2 = nb2 / sizeof(float);
const size_t s3 = nb3 / sizeof(float);

// dps == dst per src0, used for group query attention
const int64_t dps2 = ne2 / ne02;
const int64_t dps3 = ne3 / ne03;

// TODO batched matrix multiplication
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
}
}
}
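A rough reading of the loop above (the exact transpose handling depends on src1_cublas_op): out_prod now iterates over the batch dimensions instead of asserting ne2 == ne3 == 1, computing for every (i2, i3) the SGEMM dst[:,:,i2,i3] = src0[:,:,i2/dps2, i3/dps3] \cdot src1[:,:,i2,i3]^\top, where dps2 = ne2/ne02 and dps3 = ne3/ne03 let several dst batches share one src0 batch (group query attention).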
@ -39,9 +39,9 @@ static __device__ void rope_yarn(

template<bool forward, bool has_ff, typename T>
static __global__ void rope_norm(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@ -83,9 +83,9 @@ static __global__ void rope_norm(

template<bool forward, bool has_ff, typename T>
static __global__ void rope_neox(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@ -127,9 +127,9 @@ static __global__ void rope_neox(

template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@ -187,9 +187,9 @@ static __global__ void rope_multi(

template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
@ -234,9 +234,9 @@ static __global__ void rope_vision(

template<bool forward, typename T>
static void rope_norm_cuda(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -257,9 +257,9 @@ static void rope_norm_cuda(

template<bool forward, typename T>
static void rope_neox_cuda(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -280,9 +280,9 @@ static void rope_neox_cuda(

template<bool forward, typename T>
static void rope_multi_cuda(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -303,9 +303,9 @@ static void rope_multi_cuda(

template<bool forward, typename T>
static void rope_vision_cuda(
const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) {
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
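The only change to these kernels and launchers is that the __restrict__ qualifiers are dropped from the pointer parameters. __restrict__ tells the compiler the pointers never alias, which stops being a safe promise if an op may write its result over its own input. A minimal host-side sketch of that contract (illustrative only, not from this commit; __restrict__ here is the GCC/Clang/NVCC extension):

// Illustrative only: the aliasing contract that __restrict__ imposes.
#include <cstdio>

static void scale_restrict(const float * __restrict__ x, float * __restrict__ dst, int n, float s) {
    for (int i = 0; i < n; ++i) {
        dst[i] = s*x[i]; // compiler may assume x and dst never overlap
    }
}

static void scale_plain(const float * x, float * dst, int n, float s) {
    for (int i = 0; i < n; ++i) {
        dst[i] = s*x[i]; // remains well-defined even when dst == x (in-place)
    }
}

int main() {
    float buf[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    scale_plain(buf, buf, 4, 2.0f);        // in-place call: fine without __restrict__
    // scale_restrict(buf, buf, 4, 2.0f);  // the same call would break the __restrict__ promise
    std::printf("%g %g %g %g\n", buf[0], buf[1], buf[2], buf[3]);
    return 0;
}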
@ -1,5 +1,7 @@
#include "common.cuh"
#include "ggml.h"
#include "softmax.cuh"
#include <cstdint>

template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32<half>(half val) {
return __half2float(val);
}

template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

x += int64_t(rowx)*ncols;
mask += int64_t(rowy)*ncols * (mask != nullptr);
dst += int64_t(rowx)*ncols;

const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

const int warp_id = threadIdx.x / WARP_SIZE;
@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;

float max_val = -INFINITY;

@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
break;
}

const int64_t ix = (int64_t)rowx*ncols + col;
const int64_t iy = (int64_t)rowy*ncols + col;

const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);

vals[col] = val;
max_val = max(max_val, val);
@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
return;
}

const int64_t idst = (int64_t)rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
dst[col] = vals[col] * inv_sum;
}
}

static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
const int tid = threadIdx.x;
const int rowx = blockIdx.x;

grad += int64_t(rowx)*ncols;
dstf += int64_t(rowx)*ncols;
dst += int64_t(rowx)*ncols;

float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients

for (int col = tid; col < ncols; col += WARP_SIZE) {
dgf_dot += dstf[col]*grad[col];
}

dgf_dot = warp_reduce_sum(dgf_dot);

for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
}
}
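The new soft_max_back_f32 kernel applies the usual softmax gradient per row: with y = dstf from the forward pass and g = grad = dL/dy, it writes dL/dx_i = scale * y_i * (g_i - sum_j y_j g_j). A CPU sketch of the same computation, mirroring the kernel above (illustrative only; function and parameter names are mine, not ggml API):

// CPU reference sketch of the softmax backward pass for one row of ncols values.
#include <cstddef>

static void soft_max_back_row_ref(const float * grad, const float * dstf, float * dst,
                                  size_t ncols, float scale) {
    float dgf_dot = 0.0f;
    for (size_t i = 0; i < ncols; ++i) {
        dgf_dot += dstf[i]*grad[i];                     // sum_j y_j * dL/dy_j
    }
    for (size_t i = 0; i < ncols; ++i) {
        dst[i] = scale * (grad[i] - dgf_dot) * dstf[i]; // dL/dx_i
    }
}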
@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

const uint32_t n_head = nrows_x/nrows_y;
@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

// FIXME: this limit could be raised by ~2-4x on Ampere or newer
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}

static void soft_max_back_f32_cuda(
const float * grad, const float * dstf, float * dst,
const int ncols, const int nrows, const float scale, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);

soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
}

void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

const float * src0_d = (const float *)src0->data;
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
float * dst_d = (float *) dst->data;

float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

if (use_f16) {
const half * src1_dd = (const half *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
} else {
const float * src1_dd = (const float *)src1_d;

soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}
}

void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // grad
const ggml_tensor * src1 = dst->src[1]; // forward pass output

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;

cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

GGML_ASSERT(max_bias == 0.0f);

soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
}
@ -3,3 +3,5 @@
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024

void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}

static __global__ void silu_back_f32(
const float * grad, const float * xf, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= k) {
return;
}

const float xfi = xf[i];
const float s = 1.0f / (1.0f + expf(-xfi));
dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
}
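silu_back_f32 uses the closed-form derivative of SiLU: with s = sigmoid(x) and silu(x) = x*s, d silu/dx = s*(1 + x*(1 - s)), so dL/dx = dL/dy * s * (1 + x*(1 - s)). A CPU sketch of the same element-wise update (illustrative only; names are mine, not ggml API):

// CPU reference sketch of the SiLU backward pass.
#include <cmath>
#include <cstddef>

static void silu_back_ref(const float * grad, const float * x, float * dst, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        const float s = 1.0f / (1.0f + std::exp(-x[i]));   // sigmoid(x)
        dst[i] = grad[i] * s * (1.0f + x[i] * (1.0f - s)); // dL/dy * d silu(x)/dx
    }
}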
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
silu_back_f32<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}

static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // grads of forward pass output
const ggml_tensor * src1 = dst->src[1]; // input from forward pass

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;

cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
}

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@ -4,6 +4,7 @@
#define CUDA_STEP_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_SILU_BACK_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}

// ggml_soft_max_back
// ggml_soft_max_ext_back

static struct ggml_tensor * ggml_soft_max_back_impl(
static struct ggml_tensor * ggml_soft_max_ext_back_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float scale,
float max_bias,
bool inplace) {
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
result->src[0] = a;
result->src[1] = b;

memcpy((float *) result->op_params + 0, &scale, sizeof(float));
memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));

return result;
}

struct ggml_tensor * ggml_soft_max_back(
struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, false);
struct ggml_tensor * b,
float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
}

struct ggml_tensor * ggml_soft_max_back_inplace(
struct ggml_tensor * ggml_soft_max_ext_back_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_soft_max_back_impl(ctx, a, b, true);
struct ggml_tensor * b,
float scale,
float max_bias) {
return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
}
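A hedged usage sketch (not part of the commit) of the renamed API: the backward node takes the upstream gradient and the forward softmax output, and the scale/max_bias values are expected to match the forward call. The tensor sizes, scale value, and context setup below are placeholders.

// Illustrative pairing of forward and backward soft max nodes.
#include "ggml.h"

int main() {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const float scale    = 0.125f; // placeholder values
    const float max_bias = 0.0f;

    struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * grad = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4); // dL/dy

    struct ggml_tensor * y  = ggml_soft_max_ext(ctx, x, /*mask*/ nullptr, scale, max_bias);
    struct ggml_tensor * dx = ggml_soft_max_ext_back(ctx, grad, y, scale, max_bias);

    (void) dx; // dx would be added to a graph and computed by a backend

    ggml_free(ctx);
    return 0;
}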
// ggml_rope
@ -5080,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(ggml_is_scalar(c));
GGML_ASSERT(ggml_is_scalar(a));
GGML_ASSERT(ggml_are_same_shape(b, c));

struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_dup_tensor(ctx, b);

result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->src[0] = a;
@ -5262,7 +5271,7 @@ static void ggml_sub_or_set(
}

static void ggml_compute_backward(
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
struct ggml_tensor * tensor = cgraph->nodes[i];
struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);

@ -5406,7 +5415,7 @@ static void ggml_compute_backward(
if (src0_needs_grads) {
float eps;
memcpy(&eps, tensor->op_params, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
}
} break;
case GGML_OP_MUL_MAT: {
@ -5589,7 +5598,13 @@ static void ggml_compute_backward(
} break;
case GGML_OP_SOFT_MAX: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));

ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
}
GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
} break;
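For completeness, the identity this backward case relies on (standard softmax calculus, not stated in the commit): with z = scale*x + slope*mask and y = softmax(z),

\frac{\partial L}{\partial x_i}
  = \text{scale}\,\sum_j \frac{\partial L}{\partial y_j}\,\frac{\partial y_j}{\partial z_i}
  = \text{scale}\; y_i \Bigl(\frac{\partial L}{\partial y_i} - \sum_j y_j \frac{\partial L}{\partial y_j}\Bigr),

which is exactly what ggml_soft_max_ext_back and soft_max_back_f32 compute per row. The mask term only shifts z, so it drops out of the gradient with respect to x; its own gradient is not implemented, hence the assertion above.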
@ -5630,7 +5645,7 @@ static void ggml_compute_backward(
const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;

ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
}
} break;
case GGML_OP_POOL_2D: {
@ -5673,7 +5688,7 @@ static void ggml_compute_backward(
} break;
case GGML_UNARY_OP_SILU: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
}
} break;
case GGML_UNARY_OP_EXP: {
@ -5690,7 +5705,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_CROSS_ENTROPY_LOSS: {
if (src0_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
}
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break;