mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-31 00:24:07 +00:00
llama : add phi3 128K model support (llama/7225)
* add phi3 128k support in convert-hf-to-gguf * add phi3 128k support in cuda * address build warnings on llama.cpp * adjust index value in cuda long rope freq factors * add long rope support in ggml cpu backend * make freq factors only depend on ctx size * remove unused rope scaling type 'su' frin gguf converter * fix flint warnings on convert-hf-to-gguf.py * set to the short freq factor when context size is small than trained context size * add one line of comments * metal : support rope freq_factors * ggml : update ggml_rope_ext API to support freq. factors * backends : add dev messages to support rope freq. factors * minor : style * tests : update to use new rope API * backends : fix pragma semicolons * minor : cleanup * llama : move rope factors from KV header to tensors * llama : remove tmp assert * cuda : fix compile warning * convert : read/write n_head_kv * llama : fix uninitialized tensors --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
bbdbc3fc62
commit
c9dcb75118
@ -58,10 +58,10 @@ static __global__ void rope(
|
||||
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
|
||||
template<typename T, bool has_pos>
|
||||
template<typename T, bool has_pos, bool has_freq_facs>
|
||||
static __global__ void rope_neox(
|
||||
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
|
||||
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
|
||||
) {
|
||||
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
|
||||
|
||||
@ -88,7 +88,9 @@ static __global__ void rope_neox(
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
|
||||
const int p = has_pos ? pos[i2] : 0;
|
||||
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
|
||||
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
|
||||
|
||||
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
@ -164,7 +166,7 @@ static void rope_cuda(
|
||||
template<typename T>
|
||||
static void rope_neox_cuda(
|
||||
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
||||
) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
|
||||
@ -175,15 +177,29 @@ static void rope_neox_cuda(
|
||||
const float inv_ndims = -1.0f / n_dims;
|
||||
|
||||
if (pos == nullptr) {
|
||||
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims
|
||||
);
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims, freq_factors
|
||||
);
|
||||
} else {
|
||||
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims, freq_factors
|
||||
);
|
||||
}
|
||||
} else {
|
||||
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims
|
||||
);
|
||||
if (freq_factors == nullptr) {
|
||||
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims, freq_factors
|
||||
);
|
||||
} else {
|
||||
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
|
||||
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, inv_ndims, freq_factors
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,24 +230,27 @@ static void rope_cuda_f32(
|
||||
|
||||
static void rope_neox_cuda_f16(
|
||||
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
|
||||
|
||||
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
|
||||
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
static void rope_neox_cuda_f32(
|
||||
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
|
||||
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
|
||||
) {
|
||||
|
||||
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
|
||||
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
const int32_t * pos = nullptr;
|
||||
if ((mode & 1) == 0) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(src1->ne[0] == ne2);
|
||||
pos = (const int32_t *) src1_d;
|
||||
}
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
if (is_neox) {
|
||||
pos = (const int32_t *) src1_d;
|
||||
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_neox_cuda_f32(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, stream
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_neox_cuda_f16(
|
||||
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, stream
|
||||
attn_factor, corr_dims, freq_factors, stream
|
||||
);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
|
@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
} break;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
#pragma message("TODO: implement phi3 frequency factors support")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
||||
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
||||
|
||||
GGML_ASSERT(ne10 == ne02);
|
||||
GGML_ASSERT(src0t == dstt);
|
||||
// const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
|
121
ggml-metal.m
121
ggml-metal.m
@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
||||
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
||||
|
||||
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
||||
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
||||
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
||||
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
|
||||
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
||||
|
||||
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
||||
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
||||
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
||||
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
||||
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
||||
|
||||
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
||||
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
||||
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
||||
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
||||
|
||||
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
||||
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
||||
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
||||
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
||||
|
||||
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
||||
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
||||
@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
const int n_as = src0->ne[2];
|
||||
|
||||
// src2 = ids
|
||||
const int64_t ne20 = src2->ne[0];
|
||||
const int64_t ne21 = src2->ne[1];
|
||||
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2->nb[1];
|
||||
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
||||
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
||||
|
||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||
|
||||
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
||||
@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
||||
|
||||
if (!is_neox) {
|
||||
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:19];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
|
||||
if (id_src2 != nil) {
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
||||
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
||||
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
||||
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
||||
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
||||
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
||||
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
||||
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
||||
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
||||
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
||||
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
||||
|
||||
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
||||
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
||||
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
||||
|
@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
|
||||
typedef void (rope_t)(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
@ -1675,6 +1676,7 @@ template<typename T>
|
||||
kernel void kernel_rope(
|
||||
device const void * src0,
|
||||
device const int32_t * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
@ -1744,8 +1746,10 @@ kernel void kernel_rope(
|
||||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
const float cur_rot = inv_ndims*ic - ib;
|
||||
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
|
||||
|
||||
const float theta = theta_0 * pow(freq_base, cur_rot);
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
|
@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const dpct::queue_ptr &main_stream) {
|
||||
#pragma message("TODO: implement phi3 frequency factors support")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
||||
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
|
||||
}
|
||||
|
||||
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
#pragma message("TODO: implement phi3 frequency factors support")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
|
||||
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
|
||||
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
// const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
|
88
ggml.c
88
ggml.c
@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
|
||||
if (c) {
|
||||
GGML_ASSERT(c->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->ne[0] >= n_dims / 2);
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
@ -6295,7 +6302,49 @@ struct ggml_tensor * ggml_rope_inplace(
|
||||
int mode,
|
||||
int n_ctx) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
@ -6314,7 +6363,7 @@ struct ggml_tensor * ggml_rope_custom(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
|
||||
);
|
||||
}
|
||||
@ -6334,27 +6383,18 @@ struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
float beta_fast,
|
||||
float beta_slow) {
|
||||
return ggml_rope_impl(
|
||||
ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
|
||||
);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down) {
|
||||
return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
|
||||
}
|
||||
|
||||
// ggml_rope_back
|
||||
|
||||
struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
GGML_ASSERT(c == NULL && "freq factors not implemented yet");
|
||||
|
||||
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
|
||||
|
||||
@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
const struct ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||
return;
|
||||
@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (is_neox) {
|
||||
if (src2 != NULL) {
|
||||
GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
|
||||
}
|
||||
|
||||
// backward process uses inverse rotation by cos and sin.
|
||||
// cos and sin build a rotation matrix, where the inverse is the transpose.
|
||||
// this essentially just switches the sign of sin.
|
||||
@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
|
||||
|
||||
// simplified from `(ib * n_dims + ic) * inv_ndims`
|
||||
float cur_rot = inv_ndims * ic - ib;
|
||||
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
|
||||
|
||||
float cos_theta, sin_theta;
|
||||
rope_yarn(
|
||||
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
||||
&cos_theta, &sin_theta
|
||||
);
|
||||
sin_theta *= sin_sign;
|
||||
@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
|
||||
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
|
||||
struct ggml_tensor * src0 = tensor->src[0];
|
||||
struct ggml_tensor * src1 = tensor->src[1];
|
||||
struct ggml_tensor * src2 = tensor->src[2];
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_DUP:
|
||||
@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
ggml_rope_back(ctx,
|
||||
tensor->grad,
|
||||
src1,
|
||||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
ggml_rope_impl(ctx,
|
||||
tensor->grad,
|
||||
src1,
|
||||
src2,
|
||||
n_dims,
|
||||
mode,
|
||||
n_ctx,
|
||||
@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
masked);
|
||||
}
|
||||
|
||||
struct ggml_tensor * src2 = tensor->src[2];
|
||||
const int64_t elem_q = ggml_nelements(src0);
|
||||
const int64_t elem_k = ggml_nelements(src1);
|
||||
const int64_t elem_v = ggml_nelements(src2);
|
||||
|
49
ggml.h
49
ggml.h
@ -1465,6 +1465,7 @@ extern "C" {
|
||||
// if mode & 4 == 1, ChatGLM style
|
||||
//
|
||||
// b is an int32 vector with size a->ne[2], it contains the positions
|
||||
// c is freq factors (e.g. phi3-128k), (optional)
|
||||
GGML_API struct ggml_tensor * ggml_rope(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -1483,10 +1484,11 @@ extern "C" {
|
||||
int n_ctx);
|
||||
|
||||
// custom RoPE
|
||||
GGML_API struct ggml_tensor * ggml_rope_custom(
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
@ -1499,7 +1501,23 @@ extern "C" {
|
||||
float beta_slow);
|
||||
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
@ -1512,20 +1530,28 @@ extern "C" {
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow);
|
||||
float beta_slow),
|
||||
"use ggml_rope_ext instead");
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// xPos RoPE, in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int n_dims,
|
||||
float base,
|
||||
bool down);
|
||||
int mode,
|
||||
int n_ctx,
|
||||
int n_orig_ctx,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow),
|
||||
"use ggml_rope_ext_inplace instead");
|
||||
|
||||
// compute correction dims for YaRN RoPE scaling
|
||||
GGML_CALL void ggml_rope_yarn_corr_dims(
|
||||
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
||||
|
||||
// rotary position embedding backward, i.e compute dx from dy
|
||||
// a - dy
|
||||
@ -1533,6 +1559,7 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx,
|
||||
|
Loading…
x
Reference in New Issue
Block a user