From 1f35ce61c114c076d2b98633fcd86b0080f32a80 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 29 May 2024 20:17:31 +0300
Subject: [PATCH] ggml : fix YARN + add tests + add asserts (llama/7617)

* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci
---
 ggml-cuda.cu      |  4 ++-
 ggml-cuda/norm.cu |  6 ++++
 ggml-cuda/rope.cu | 18 +++++-------
 ggml-kompute.cpp  |  4 ++-
 ggml-metal.m      |  8 ++++-
 ggml-metal.metal  | 16 ++++------
 ggml-sycl.cpp     |  2 +-
 ggml.c            | 74 +++++++++++++++++++++++------------------------
 ggml.h            |  6 +++-
 9 files changed, 75 insertions(+), 63 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d0a754ee..1172f7b2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
diff --git a/ggml-cuda/norm.cu b/ggml-cuda/norm.cu
index 86f77453..30866d51 100644
--- a/ggml-cuda/norm.cu
+++ b/ggml-cuda/norm.cu
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu
index 50f2cf41..0dd07977 100644
--- a/ggml-cuda/rope.cu
+++ b/ggml-cuda/rope.cu
@@ -61,7 +61,7 @@ static __global__ void rope(
 template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -85,15 +85,13 @@ static __global__ void rope_neox(
     const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    float cur_rot = inv_ndims * ic - ib;
-
     const int p = has_pos ? pos[i2] : 0;
     const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
+    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
 
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
     const dim3 block_nums(nrows, num_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
         if (freq_factors == nullptr) {
             rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
             );
         } else {
             rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
            );
        }
    } else {
        if (freq_factors == nullptr) {
            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
            );
        } else {
            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
            );
        }
    }
@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 6c6058b2..ed59d2be 100644
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             {
                 GGML_ASSERT(ne00 == ne10);
 
-                // TODO: assert that dim2 and dim3 are contiguous
+                GGML_ASSERT(ggml_is_contiguous_2(src0));
+                GGML_ASSERT(ggml_is_contiguous_2(src1));
+
                 GGML_ASSERT(ne12 % ne02 == 0);
                 GGML_ASSERT(ne13 % ne03 == 0);
 
diff --git a/ggml-metal.m b/ggml-metal.m
index 4ba498e8..a7e13bdc 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
             {
                 GGML_ASSERT(ne00 == ne10);
 
-                // TODO: assert that dim2 and dim3 are contiguous
+                GGML_ASSERT(ggml_is_contiguous_2(src0));
+                GGML_ASSERT(ggml_is_contiguous_2(src1));
+
                 GGML_ASSERT(ne12 % ne02 == 0);
                 GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
@@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_GROUP_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
                 //float eps;
                 //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_NORM:
             {
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
diff --git a/ggml-metal.metal b/ggml-metal.metal
index b16f2b7e..0cb85e1a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;
 
-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                const float cur_rot = inv_ndims*ic - ib;
 
-                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
 
-                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
 
                 float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                const int64_t i0 = ib*n_dims + ic/2;
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index a7344813..5cd97e4f 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
diff --git a/ggml.c b/ggml.c
index e6e2397b..b2b725f6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
@@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
 
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
                 const float cos_block_theta = cosf(block_theta);
                 const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                theta_base *= theta_scale;
+                theta_base  *= theta_scale;
                 block_theta *= theta_scale;
 
                 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta  *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
@@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
 
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
                 const float cos_block_theta = cosf(block_theta);
                 const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                theta_base *= theta_scale;
+                theta_base  *= theta_scale;
                 block_theta *= theta_scale;
 
                 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta  *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
diff --git a/ggml.h b/ggml.h
index f9deac7e..f3869969 100644
--- a/ggml.h
+++ b/ggml.h
@@ -756,7 +756,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty    (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);
@@ -765,6 +764,11 @@ extern "C" {
     GGML_API bool ggml_is_3d        (const struct ggml_tensor * tensor);
     GGML_API int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
 
+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
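
A note on the new contiguity predicates (an editor's illustration, not part of the commit): ggml_is_contiguous_0/1/2 form a family in which dim 0 must always be densely packed, every dim above n must have the stride implied by the dim below it, and the strides of dims 1..n are left unconstrained. So ggml_is_contiguous_1() tolerates padded rows, and ggml_is_contiguous_2() is exactly the dims-2/3 check that the cuBLAS and SYCL batched-GEMM paths previously spelled out by hand. The toy program below shows the pattern; toy_tensor and toy_is_contiguous_n are made-up names, and block-quantized types, which the real predicates additionally account for via the block size, are ignored here.

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Toy 4-D tensor mirroring the two ggml fields the predicates use:
    // ne = elements per dim, nb = stride per dim in bytes.
    struct toy_tensor {
        int64_t ne[4];
        size_t  nb[4];
        size_t  type_size;
    };

    // Generalized form of ggml_is_contiguous_{0,1,2}: dim 0 must be densely
    // packed, dims above n must have the stride implied by the dim below,
    // and dims 1..n may carry padding.
    static bool toy_is_contiguous_n(const struct toy_tensor * t, int n) {
        if (t->nb[0] != t->type_size) {
            return false; // dim 0 is never allowed to be strided
        }
        for (int i = n + 1; i < 4; ++i) {
            if (t->nb[i] != t->nb[i - 1]*(size_t) t->ne[i - 1]) {
                return false; // a dim above n has extra padding
            }
        }
        return true;
    }

    int main(void) {
        // 4x3 f32 tensor whose rows are padded to 5 floats: dim 0 is dense,
        // but nb[1] (20 bytes) exceeds ne[0]*nb[0] (16 bytes).
        struct toy_tensor t = {
            .ne = {4, 3, 1, 1},
            .nb = {4, 20, 60, 60},
            .type_size = 4,
        };
        printf("contiguous_0: %d\n", toy_is_contiguous_n(&t, 0)); // 0 (row padding)
        printf("contiguous_1: %d\n", toy_is_contiguous_n(&t, 1)); // 1
        printf("contiguous_2: %d\n", toy_is_contiguous_n(&t, 2)); // 1
        return 0;
    }

The padded example fails the full-contiguity check but passes the dims >= 1 and dims >= 2 checks, which is what lets each op assert only as much layout as it actually requires.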
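
On the YARN fix itself (again an editor's sketch, not part of the commit): the NeoX branches now hand rope_yarn() the raw channel index ic instead of the old normalized cur_rot = inv_ndims*ic - ib, which matters because the corr_dims cutoffs produced by ggml_rope_yarn_corr_dims() are expressed in channel units; the CUDA path additionally stops pre-multiplying theta_base by freq_scale, a factor rope_yarn() already applies internally, so the scaling was being applied twice. The loops advance the base angle incrementally via theta_base *= theta_scale with theta_scale = freq_base^(-2/n_dims). The self-contained C program below, which assumes nothing beyond the arithmetic visible in the diff, checks that this incremental form matches the closed form theta(ic) = p * freq_base^(-ic/n_dims):

    // Illustration only: incremental vs. closed-form NeoX RoPE base angles.
    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_dims    = 8;        // rotated dimensions
        const int   p         = 3;        // token position
        const float freq_base = 10000.0f;

        const float theta_scale = powf(freq_base, -2.0f/n_dims);

        float theta_base = (float) p;     // incremental form, as in the patched loops
        for (int ic = 0; ic < n_dims; ic += 2) {
            const float theta_closed = p*powf(freq_base, -(float) ic/n_dims);
            printf("ic=%d  incremental=%.6f  closed=%.6f\n", ic, theta_base, theta_closed);
            theta_base *= theta_scale;    // advance to the next channel pair
        }
        return 0;
    }

The per-pair frequency factor from freq_factors[ic/2] then divides this base angle before rope_yarn() applies the YARN interpolation and attention scaling.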