diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a9c051cd..1198dc1f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1384,16 +1384,20 @@ extern "C" { float scale, float max_bias); - GGML_API struct ggml_tensor * ggml_soft_max_back( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // rotary position embedding // if (mode & 1) - skip n_past elements (NOT SUPPORTED) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 8dc8226a..9a3bf9f2 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +// ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_LOG: case GGML_OP_UNARY: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: return true; default: diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8bf5f781..8040e20f 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -6691,20 +6691,20 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * grad = dst->src[1]; + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; assert(ggml_is_contiguous_1(grad)); - assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(src1)); assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, grad)); + assert(ggml_are_same_shape(src1, dst)); + assert(ggml_are_same_shape(src1, grad)); const int ith = params->ith; const int nth = params->nth; - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); + const int nc = src1->ne[0]; + const int nr = ggml_nrows(src1); // rows per thread const int dr = (nr + nth - 1)/nth; @@ -6716,7 +6716,7 @@ static void ggml_compute_forward_silu_back_f32( for (int i1 = ir0; i1 < ir1; i1++) { ggml_vec_silu_backward_f32(nc, (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) src1->data + i1*(src1->nb[1])), (float *) ((char *) grad->data + i1*(grad->nb[1]))); #ifndef NDEBUG @@ -6895,7 +6895,7 @@ static void ggml_compute_forward_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -6966,7 +6966,7 @@ static void ggml_compute_forward_rms_norm_f32( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - GGML_ASSERT(eps > 0.0f); + GGML_ASSERT(eps >= 0.0f); // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { @@ -7018,12 +7018,13 @@ static void ggml_compute_forward_rms_norm_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -7042,8 +7043,8 @@ static void ggml_compute_forward_rms_norm_back_f32( const int64_t i12 = i02; const int64_t i13 = i03; - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); ggml_float sum_xx = 0.0; ggml_float sum_xdz = 0.0; @@ -7066,9 +7067,9 @@ static void ggml_compute_forward_rms_norm_back_f32( { // z = rms_norm(x) // - // rms_norm(src0) = + // rms_norm(src1) = // scale( - // src0, + // src1, // div( // 1, // sqrt( @@ -7076,13 +7077,13 @@ static void ggml_compute_forward_rms_norm_back_f32( // scale( // sum( // sqr( - // src0)), + // src1)), // (1.0/N)), // eps)))); // postorder: // ## op args grad - // 00 param src0 grad[#00] + // 00 param src1 grad[#00] // 01 const 1 // 02 sqr (#00) grad[#02] // 03 sum (#02) grad[#03] @@ -7159,6 +7160,7 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) ggml_vec_cpy_f32 (ne00, dx, x); // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); @@ -7750,12 +7752,13 @@ static void ggml_compute_forward_out_prod_f32( const int ith = params->ith; const int nth = params->nth; - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne3 == ne13); - GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == sizeof(float)); @@ -7797,6 +7800,10 @@ static void ggml_compute_forward_out_prod_f32( const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); const int64_t blck_1 = 16; + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { const int64_t bir1 = MIN(bir + blck_1, ir1); for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { @@ -7807,8 +7814,8 @@ static void ggml_compute_forward_out_prod_f32( const int64_t i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - const int64_t i02 = i2; - const int64_t i03 = i3; + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; //const int64_t i10 = i1; const int64_t i12 = i2; @@ -8906,9 +8913,9 @@ static void ggml_compute_forward_soft_max( } -// ggml_compute_forward_soft_max_back +// ggml_compute_forward_soft_max_ext_back -static void ggml_compute_forward_soft_max_back_f32( +static void ggml_compute_forward_soft_max_ext_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8921,6 +8928,14 @@ static void ggml_compute_forward_soft_max_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -8969,10 +8984,11 @@ static void ggml_compute_forward_soft_max_back_f32( // linear runtime, no additional memory float dot_y_dy = 0; - ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); - ggml_vec_cpy_f32 (nc, dx, dy); - ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32 (nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_scale_f32(nc, dx, scale); #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -8983,7 +8999,7 @@ static void ggml_compute_forward_soft_max_back_f32( } } -static void ggml_compute_forward_soft_max_back( +static void ggml_compute_forward_soft_max_ext_back( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -8992,7 +9008,7 @@ static void ggml_compute_forward_soft_max_back( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_soft_max_back_f32(params, dst); + ggml_compute_forward_soft_max_ext_back_f32(params, dst); } break; default: { @@ -9985,9 +10001,10 @@ static void ggml_compute_forward_im2col_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -10009,11 +10026,11 @@ static void ggml_compute_forward_im2col_back_f32( const int64_t IH = is_2D ? ne1 : 1; const int64_t IW = ne0; - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; + const int64_t KH = is_2D ? ne11 : 1; + const int64_t KW = ne10; - const int64_t OH = is_2D ? ne12 : 1; - const int64_t OW = ne11; + const int64_t OH = is_2D ? ne02 : 1; + const int64_t OW = ne01; int ofs0 = is_2D ? nb3 : nb2; int ofs1 = is_2D ? nb2 : nb1; @@ -10059,9 +10076,9 @@ static void ggml_compute_forward_im2col_back_f32( continue; } - const float * const src_data = (const float *) src1->data + const float * const grad_in = (const float *) src0->data + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; + grad += grad_in[iic*(KH*KW) + ikh*KW + ikw]; } } float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] @@ -12484,22 +12501,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * opt0 = dst->src[2]; + const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output + const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass + const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst)); const int64_t ith = params->ith; const int64_t nth = params->nth; // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = ggml_nrows(src0); + const int64_t nc = src0f->ne[0]; + const int64_t nr = ggml_nrows(src0f); // rows per thread const int64_t dr = (nr + nth - 1)/nth; @@ -12508,12 +12525,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = MIN(ir0 + dr, nr); - const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; + const float d_by_nr = ((const float *) grad->data)[0] / (float) nr; for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]); + const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]); #ifndef NDEBUG for (int64_t i = 0; i < nc; ++i) { @@ -12526,11 +12543,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( // soft_max float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); + const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); assert(sum > 0.0); ggml_vec_scale_f32(nc, ds0, 1.0/sum); - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr ggml_vec_sub_f32(nc, ds0, ds0, s1); ggml_vec_scale_f32(nc, ds0, d_by_nr); @@ -12827,7 +12844,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SOFT_MAX_BACK: { - ggml_compute_forward_soft_max_back(params, tensor); + ggml_compute_forward_soft_max_ext_back(params, tensor); } break; case GGML_OP_ROPE: { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 5c47ceb7..35a1c876 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -403,6 +403,16 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float case GGML_OP_MUL_MAT: return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; + case GGML_OP_SOFT_MAX_BACK: { + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { + return false; + } + float max_bias = 0.0f; + + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + + return max_bias == 0.0f; + } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; case GGML_OP_OUT_PROD: diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index ed09406a..27599a2b 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -5,95 +5,89 @@ #include #include -static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; +template +static __global__ void cross_entropy_loss_f32( + const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) { + extern __shared__ float tmp[]; - const int ne_tmp = WARP_SIZE*nclasses; - - extern __shared__ float tmp_all[]; - float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; - float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; - - // Each warp first loads ne_tmp logits/labels into shared memory: - for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { - const int ig = i0*nclasses + i; // ig == i global - - tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; - tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; - } - - // Each thread in the warp then calculates the cross entropy loss for a single row. - // TODO: pad in order to avoid shared memory bank conflicts. + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; // Find maximum for softmax: - float max = -INFINITY; - for (int i = 0; i < nclasses; ++i) { - max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + float max_logit = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + max_logit = fmaxf(max_logit, val); + + if (use_shared) { + tmp[i] = val; + } } + max_logit = warp_reduce_max(max_logit); // Calculate log(softmax(logits)) which is just logits - max: float sum = 0.0f; - for (int i = 0; i < nclasses; ++i) { - float val = tmp_logits[lane_id*nclasses + i] - max; - sum += expf(val); - tmp_logits[lane_id*nclasses + i] = val; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + sum += expf(logit_i - max_logit); } + sum = warp_reduce_sum(sum); sum = logf(sum); // log(exp(logits - max) / sum) = (logits - max) - log(sum) float loss = 0.0f; - for (int i = 0; i < nclasses; ++i) { - loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + loss += (logit_i - max_logit - sum) * labels[i]; } loss = -warp_reduce_sum(loss) / (float)k; - __syncthreads(); - - if (lane_id == 0) { - tmp_all[warp_id] = loss; - } - - __syncthreads(); - - if (warp_id != 0) { - return; - } - - loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; - loss = warp_reduce_sum(loss); - - if (lane_id != 0) { + if (threadIdx.x != 0) { return; } dst[blockIdx.x] = loss; } -static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) { +template +static __global__ void cross_entropy_loss_back_f32( + const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels, + float * __restrict__ dst, const int nclasses) { extern __shared__ float tmp[]; + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + dst += int64_t(blockIdx.x)*nclasses; + float maxval = -INFINITY; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = logits[blockIdx.x*nclasses + i]; + const float val = logits[i]; maxval = fmaxf(maxval, val); - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } } maxval = warp_reduce_max(maxval); float sum = 0.0f; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - const float val = expf(tmp[i] - maxval); + const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval); sum += val; - tmp[i] = val; + + if (use_shared) { + tmp[i] = val; + } else { + dst[i] = val; + } } sum = warp_reduce_sum(sum); const float sm_scale = 1.0f/sum; - const float d_by_nrows = *loss/gridDim.x; + const float d_by_nrows = *grad/gridDim.x; for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { - dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows; + const float val = use_shared ? tmp[i] : dst[i]; + dst[i] = (val*sm_scale - labels[i])*d_by_nrows; } } @@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool & pool = ctx.pool(); cudaStream_t stream = ctx.stream(); - const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); - cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } else { + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } + CUDA_CHECK(cudaGetLastError()); // Combine results from individual blocks: sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); } void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * opt0 = dst->src[2]; + const ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(opt0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT(src1f->type == GGML_TYPE_F32); + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_is_scalar(grad)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f)); + GGML_ASSERT(ggml_are_same_shape(src0f, dst)); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); - const float * src0_d = (const float *) src0->data; - const float * src1_d = (const float *) src1->data; - const float * opt0_d = (const float *) opt0->data; - float * dst_d = (float *) dst->data; + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + const float * src1f_d = (const float *) src1f->data; + float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); const dim3 blocks_dim(WARP_SIZE, 1, 1); const dim3 blocks_num(nrows, 1, 1); - const int shmem = ne00*sizeof(float); + const size_t nbytes_shared = ne00*sizeof(float); - cross_entropy_loss_back_f32<<>>(src0_d, src1_d, opt0_d, dst_d, ne00); + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } else { + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } } diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c370323..4cef53a9 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -3,15 +3,15 @@ template static __global__ void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -22,10 +22,10 @@ static __global__ void k_get_rows( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index const int iybs = i00 - i00%qk; // dst block start index const int y_offset = qr == 1 ? 1 : qk/2; @@ -39,15 +39,15 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -58,14 +58,38 @@ static __global__ void k_get_rows_float( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); dst_row[i00] = src0_row[i00]; } +template +static __global__ void k_get_rows_back_float( + const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) { + const int col = blockIdx.x*blockDim.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int dst_row = blockIdx.y*blockDim.y + threadIdx.y; + + float sum = 0.0f; + + for (int64_t i = 0; i < nrows_grad; ++i) { + if (rows[i] != dst_row) { + continue; + } + sum += grad[i*ncols + col]; + } + + dst[dst_row*ncols + col] = sum; +} + template -static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } template -static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda_float( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne13 == 1); + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num_x, ne10, ne11*ne12); @@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr //const size_t s13 = nb13 / ggml_element_size(src1); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } @@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const void * src0_d = (const void *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); switch (src0->type) { case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; default: // TODO: k-quants @@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { break; } } + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne02*ne03 == 1); + GGML_ASSERT(ne12*ne13 == 1); + GGML_ASSERT(ne2*ne3 == 1); + + const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, 1); + + k_get_rows_back_float<<>>(src0_d, src1_d, dst_d, ne00, ne10); +} diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh index bbf13023..a1ca643f 100644 --- a/ggml/src/ggml-cuda/getrows.cuh +++ b/ggml/src/ggml-cuda/getrows.cuh @@ -1,5 +1,8 @@ #include "common.cuh" #define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9118edc7..7fd1fc85 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2003,6 +2003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; + case GGML_OP_GET_ROWS_BACK: + ggml_cuda_op_get_rows_back(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2091,9 +2094,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_cuda_op_leaky_relu(ctx, dst); break; + case GGML_OP_SILU_BACK: + ggml_cuda_op_silu_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_cuda_op_rms_norm(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_cuda_op_rms_norm_back(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); @@ -2138,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_cuda_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_cuda_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; @@ -2912,7 +2924,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } } break; case GGML_OP_OUT_PROD: - return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -2928,6 +2940,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } } break; + case GGML_OP_GET_ROWS_BACK: + { + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -3001,8 +3017,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } return false; } break; + case GGML_OP_SILU_BACK: + return ggml_is_contiguous(op->src[0]); + break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; break; case GGML_OP_NONE: @@ -3027,6 +3047,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { const size_t ts = ggml_type_size(op->src[0]->type); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 133e219f..d991ec97 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - float2 mean_var = make_float2(0.f, 0.f); + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float2 mean_var = make_float2(0.0f, 0.0f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; mean_var.x += xi; mean_var.y += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float2 s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = mean_var; } @@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const float inv_std = rsqrtf(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -40,14 +44,8 @@ template static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { // blockIdx.x: num_groups idx // threadIdx.x: block_size idx - int start = blockIdx.x * group_size; - int end = start + group_size; - - start += threadIdx.x; - - if (end >= ne_elements) { - end = ne_elements; - } + const int start = blockIdx.x*group_size + threadIdx.x; + const int end = min(blockIdx.x*group_size + group_size, ne_elements); float tmp = 0.0f; // partial sum for thread in warp @@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float mean = tmp / group_size; + const float mean = tmp / group_size; tmp = 0.0f; for (int j = start; j < end; j += block_size) { - float xi = x[j] - mean; + const float xi = x[j] - mean; dst[j] = xi; tmp += xi * xi; } @@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float variance = tmp / group_size; - float scale = rsqrtf(variance + eps); + const float variance = tmp / group_size; + const float scale = rsqrtf(variance + eps); for (int j = start; j < end; j += block_size) { dst[j] *= scale; } @@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = scale * x[row*ncols + col]; + dst[col] = scale * x[col]; + } +} + +template +static __global__ void rms_norm_back_f32( + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + grad += int64_t(row)*ncols; + xf += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass + float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + + for (int col = tid; col < ncols; col += block_size) { + const float xfi = xf[col]; + sum_xx += xfi * xfi; + sum_xg += xfi * grad[col]; + } + + // sum up partial sums + sum_xx = warp_reduce_sum(sum_xx); + sum_xg = warp_reduce_sum(sum_xg); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum_xx[32]; + __shared__ float s_sum_xg[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum_xx[warp_id] = sum_xx; + s_sum_xg[warp_id] = sum_xg; + } + __syncthreads(); + + sum_xx = s_sum_xx[lane_id]; + sum_xx = warp_reduce_sum(sum_xx); + + sum_xg = s_sum_xg[lane_id]; + sum_xg = warp_reduce_sum(sum_xg); + } + + const float mean_eps = sum_xx / ncols + eps; + const float sum_eps = sum_xx + ncols*eps; + + const float scale_grad = rsqrtf(mean_eps); + const float scale_x = -scale_grad * sum_xg/sum_eps; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale_grad*grad[col] + scale_x*xf[col]; } } static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); norm_f32<<>>(x, dst, ncols, eps); @@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } -static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { +static void group_norm_f32_cuda( + const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { if (group_size < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); @@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); rms_norm_f32<<>>(x, dst, ncols, eps); @@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_back_f32<<>>(grad, xf, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_back_f32<1024><<>>(grad, xf, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } @@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); @@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; // gradients + const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(grad)); + + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 431a8f74..d63d3438 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index 619cfdcb..73e3e2c4 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ne01 == ne11); GGML_ASSERT(ne0 == ne00); GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == src0->ne[2]); + GGML_ASSERT(ne2 % src0->ne[2] == 0); + GGML_ASSERT(ne3 % src0->ne[3] == 0); + GGML_ASSERT(ne2 == src1->ne[2]); - GGML_ASSERT(ne3 == src0->ne[3]); GGML_ASSERT(ne3 == src1->ne[3]); const float * src0_d = (const float *) src0->data; @@ -33,8 +32,6 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float alpha = 1.0f; const float beta = 0.0f; - GGML_ASSERT(ne2 == 1); - GGML_ASSERT(ne3 == 1); CUBLAS_CHECK(cublasSetStream(handle, stream)); const bool src1_T = ggml_is_transposed(src1); @@ -42,10 +39,27 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); - CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, - ne0, ne1, ne01, - &alpha, src0_d, ne00, - src1_d, ldb, - &beta, dst_d, ne0)); + // data strides in dimensions 2/3 + const size_t s02 = nb02 / sizeof(float); + const size_t s03 = nb03 / sizeof(float); + const size_t s12 = nb12 / sizeof(float); + const size_t s13 = nb13 / sizeof(float); + const size_t s2 = nb2 / sizeof(float); + const size_t s3 = nb3 / sizeof(float); + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + // TODO batched matrix multiplication + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ne0)); + } + } } diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index e1912fee..18f691b2 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -39,9 +39,9 @@ static __device__ void rope_yarn( template static __global__ void rope_norm( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -83,9 +83,9 @@ static __global__ void rope_norm( template static __global__ void rope_neox( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -127,9 +127,9 @@ static __global__ void rope_neox( template static __global__ void rope_multi( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -187,9 +187,9 @@ static __global__ void rope_multi( template static __global__ void rope_vision( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * __restrict__ pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * __restrict__ freq_factors, const mrope_sections sections) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { @@ -234,9 +234,9 @@ static __global__ void rope_vision( template static void rope_norm_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -257,9 +257,9 @@ static void rope_norm_cuda( template static void rope_neox_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -280,9 +280,9 @@ static void rope_neox_cuda( template static void rope_multi_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -303,9 +303,9 @@ static void rope_multi_cuda( template static void rope_vision_cuda( - const T * __restrict__ x, T * __restrict__ dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * __restrict__ pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * __restrict__ freq_factors, const mrope_sections sections, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c24abae1..9aa4b848 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,5 +1,7 @@ #include "common.cuh" +#include "ggml.h" #include "softmax.cuh" +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -11,14 +13,20 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } -template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __global__ void soft_max_f32( + const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, + const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + x += int64_t(rowx)*ncols; + mask += int64_t(rowy)*ncols * (mask != nullptr); + dst += int64_t(rowx)*ncols; + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; @@ -29,7 +37,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; float max_val = -INFINITY; @@ -41,10 +49,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst break; } - const int64_t ix = (int64_t)rowx*ncols + col; - const int64_t iy = (int64_t)rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -110,8 +115,29 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst return; } - const int64_t idst = (int64_t)rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} + +static __global__ void soft_max_back_f32( + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } @@ -121,7 +147,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head = nrows_x/nrows_y; @@ -131,50 +157,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { + if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { - const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } +static void soft_max_back_f32_cuda( + const float * grad, const float * dstf, float * dst, + const int ncols, const int nrows, const float scale, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + soft_max_back_f32<<>>(grad, dstf, dst, ncols, scale); +} + void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const void * src1_d = src1 ? (const void *)src1->data : nullptr; + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + float * dst_d = (float *) dst->data; - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -189,18 +233,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { - const half * src1_dd = (const half *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { - const float * src1_dd = (const float *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 4ef4ff86..93dfee83 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -3,3 +3,5 @@ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 81fc9220..6b21f407 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void silu_back_f32( + const float * grad, const float * xf, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xfi = xf[i]; + const float s = 1.0f / (1.0f + expf(-xfi)); + dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_f32<<>>(grad, x, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // input from forward pass + const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index c9193672..e7f62643 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -4,6 +4,7 @@ #define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_SILU_BACK_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_SIGMOID_BLOCK_SIZE 256 @@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0a0dcc7e..d83b1b8a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3454,12 +3454,14 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } -// ggml_soft_max_back +// ggml_soft_max_ext_back -static struct ggml_tensor * ggml_soft_max_back_impl( +static struct ggml_tensor * ggml_soft_max_ext_back_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + float scale, + float max_bias, bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -3467,21 +3469,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl( result->src[0] = a; result->src[1] = b; + memcpy((float *) result->op_params + 0, &scale, sizeof(float)); + memcpy((float *) result->op_params + 1, &max_bias, sizeof(float)); + return result; } -struct ggml_tensor * ggml_soft_max_back( +struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, false); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false); } -struct ggml_tensor * ggml_soft_max_back_inplace( +struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, true); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true); } // ggml_rope @@ -5080,10 +5089,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_tensor * c) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_scalar(c)); + GGML_ASSERT(ggml_is_scalar(a)); + GGML_ASSERT(ggml_are_same_shape(b, c)); - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_dup_tensor(ctx, b); result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; @@ -5262,7 +5271,7 @@ static void ggml_sub_or_set( } static void ggml_compute_backward( - struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { struct ggml_tensor * tensor = cgraph->nodes[i]; struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); @@ -5406,7 +5415,7 @@ static void ggml_compute_backward( if (src0_needs_grads) { float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps)); } } break; case GGML_OP_MUL_MAT: { @@ -5589,7 +5598,13 @@ static void ggml_compute_backward( } break; case GGML_OP_SOFT_MAX: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias)); } GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); } break; @@ -5630,7 +5645,7 @@ static void ggml_compute_backward( const int32_t d1 = ggml_get_op_params_i32(tensor, 5); const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; - ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); } } break; case GGML_OP_POOL_2D: { @@ -5673,7 +5688,7 @@ static void ggml_compute_backward( } break; case GGML_UNARY_OP_SILU: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); } } break; case GGML_UNARY_OP_EXP: { @@ -5690,7 +5705,7 @@ static void ggml_compute_backward( } break; case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0_needs_grads) { - ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1)); } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break;