mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-21 13:37:47 +00:00
feat: ref. cross entropy, add CUDA, fix grad test (ggml/929)
This commit is contained in:
parent
df06468d9e
commit
8954769aa2
@ -63,6 +63,7 @@ extern "C" {
|
||||
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
// "offset" refers to the offset of the tensor data for setting/getting data
|
||||
GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
|
@ -9,8 +9,10 @@
|
||||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
#include "ggml-cuda/dmmv.cuh"
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
@ -29,7 +31,6 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
@ -2312,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -2619,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (node->src[j] != nullptr) {
|
||||
assert(node->src[j]->buffer);
|
||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
|
||||
}
|
||||
}
|
||||
@ -2902,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
}
|
||||
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
return true;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
default:
|
||||
return false;
|
||||
|
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
@ -0,0 +1,106 @@
|
||||
#include "common.cuh"
|
||||
#include "cross-entropy-loss.cuh"
|
||||
#include "sumrows.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
|
||||
|
||||
const int ne_tmp = WARP_SIZE*nclasses;
|
||||
|
||||
extern __shared__ float tmp_all[];
|
||||
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
|
||||
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
|
||||
|
||||
// Each warp first loads ne_tmp logits/labels into shared memory:
|
||||
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
|
||||
const int ig = i0*nclasses + i; // ig == i global
|
||||
|
||||
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
|
||||
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
|
||||
}
|
||||
|
||||
// Each thread in the warp then calculates the cross entropy loss for a single row.
|
||||
// TODO: pad in order to avoid shared memory bank conflicts.
|
||||
|
||||
// Find maximum for softmax:
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
|
||||
}
|
||||
|
||||
// Calculate log(softmax(logits)) which is just logits - max:
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
float val = tmp_logits[lane_id*nclasses + i] - max;
|
||||
sum += expf(val);
|
||||
tmp_logits[lane_id*nclasses + i] = val;
|
||||
}
|
||||
sum = logf(sum);
|
||||
|
||||
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
|
||||
float loss = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
|
||||
}
|
||||
loss = -warp_reduce_sum(loss) / (float)k;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
tmp_all[warp_id] = loss;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
|
||||
loss = warp_reduce_sum(loss);
|
||||
|
||||
if (lane_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[blockIdx.x] = loss;
|
||||
}
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
|
||||
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
|
||||
}
|
||||
}
|
||||
|
||||
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -2671,6 +2671,19 @@ static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x,
|
||||
return sum;
|
||||
}
|
||||
|
||||
static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
|
||||
// log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i)
|
||||
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - max;
|
||||
y[i] = val;
|
||||
sum += (ggml_float)expf(val);
|
||||
}
|
||||
return sum = (ggml_float)logf(sum);
|
||||
}
|
||||
|
||||
inline static float ggml_silu_backward_f32(float x, float dy) {
|
||||
const float s = 1.0f/(1.0f + expf(-x));
|
||||
return dy*s*(1.0f + x*(1.0f - s));
|
||||
@ -17022,8 +17035,6 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
||||
}
|
||||
ggml_barrier(params->shared);
|
||||
|
||||
const double eps = 1e-9;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
@ -17044,20 +17055,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
|
||||
}
|
||||
#endif
|
||||
|
||||
// soft_max
|
||||
float max = -INFINITY;
|
||||
ggml_vec_max_f32(nc, &max, s0);
|
||||
ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
|
||||
assert(sum > 0.0);
|
||||
sum = (1.0 - eps) / sum;
|
||||
ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
|
||||
assert(sum >= 0.0);
|
||||
|
||||
// avoid log(0) by rescaling from [0..1] to [eps..1]
|
||||
ggml_vec_scale_f32(nc, st, sum);
|
||||
ggml_vec_add1_f32(nc, st, st, eps);
|
||||
ggml_vec_log_f32(nc, st, st);
|
||||
ggml_vec_add1_f32(nc, st, st, -sum);
|
||||
ggml_vec_mul_f32(nc, st, st, s1);
|
||||
|
||||
float st_sum = 0;
|
||||
float st_sum = 0.0f;
|
||||
ggml_vec_sum_f32(nc, &st_sum, st);
|
||||
sums[ith] += st_sum;
|
||||
|
||||
@ -17114,8 +17120,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
||||
const int64_t ith = params->ith;
|
||||
const int64_t nth = params->nth;
|
||||
|
||||
const double eps = 1e-9;
|
||||
|
||||
// TODO: handle transposed/permuted matrices
|
||||
const int64_t nc = src0->ne[0];
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
@ -17147,11 +17151,9 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
||||
ggml_vec_max_f32(nc, &max, s0);
|
||||
ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
|
||||
assert(sum > 0.0);
|
||||
sum = (1.0 - eps) / sum;
|
||||
ggml_vec_scale_f32(nc, ds0, 1.0/sum);
|
||||
|
||||
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
|
||||
ggml_vec_scale_f32(nc, ds0, sum);
|
||||
ggml_vec_add1_f32(nc, ds0, ds0, eps);
|
||||
ggml_vec_sub_f32(nc, ds0, ds0, s1);
|
||||
ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
|
||||
|
||||
@ -20287,6 +20289,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data) {
|
||||
GGML_ASSERT(ggml_is_scalar(f));
|
||||
GGML_ASSERT(f->type == GGML_TYPE_F32);
|
||||
|
||||
// these will store the parameters we want to optimize
|
||||
struct ggml_tensor * ps[GGML_MAX_PARAMS];
|
||||
|
Loading…
Reference in New Issue
Block a user