llama : add Qwen2VL support + multimodal RoPE (llama/10361)

* Barebone Qwen2VL LLM convertor

* Add Qwen2VL cli entrypoint

* [WIP] add qwen2vl arch

* Verify m-rope output

* Add vl-rope/2d-rope support for qwen2vl ViT

* update qwen2vl cli tool

* update 5D tensor op workaround

* [WIP] qwen2vl vision model

* make batch and clip utils compatible with qwen2vl

* [WIP] create inference workflow, gguf convert script but fix

* correcting vision-rope behavior, add the missing last layer back to ViT

* add arg parser to qwen2vl_surgery

* replace variable size array with vector

* cuda-gdb cmake preset

* add fp32 mrope, vision rope kernel

* add fp16 support for qwen2vl and m-rope

* add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`

* fix rope op mode switching, out dated func args

* update `llama_hparams`

* update to keep up stream changes

* resolve linter, test errors

* add makefile entry, update speical image padding token

* add mrope unit test, fix few compiler warnings

* rename `mrope` related function, params

* minor updates on debug util, bug fixs

* add `m-rope` testcase to `test-backend-ops`

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix traililng whitespce

* store `llama_hparams.rope_sections` with fixed size array

* update position id tensor size check in GGML_OP_ROPE

* minor updates

* update `ggml_backend_*_supports_op` of unsupported backends

* remote old `rope_section` compare operator

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
HimariO 2024-12-14 20:43:46 +08:00 committed by Georgi Gerganov
parent 856fbaa92f
commit e22d38e4f2
9 changed files with 564 additions and 42 deletions

View File

@ -238,6 +238,8 @@
#define GGML_EXIT_ABORTED 1
#define GGML_ROPE_TYPE_NEOX 2
#define GGML_ROPE_TYPE_MROPE 8
#define GGML_ROPE_TYPE_VISION 24
#define GGUF_MAGIC "GGUF"
@ -1443,6 +1445,22 @@ extern "C" {
float beta_fast,
float beta_slow);
GGML_API struct ggml_tensor * ggml_rope_multi(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,

View File

@ -1747,6 +1747,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (*ext_factor != 0) {
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return true;
}
case GGML_OP_UPSCALE: {

View File

@ -9133,6 +9133,64 @@ static void ggml_rope_cache_init(
}
}
static void ggml_mrope_cache_init(
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
float * cache, float sin_sign, float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta_t = theta_base_t;
float theta_h = theta_base_h;
float theta_w = theta_base_w;
float theta_e = theta_base_e; // extra position id for vision encoder
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
int sec_w = sections[1] + sections[0];
int sec_e = sections[2] + sec_w;
GGML_ASSERT(sect_dims <= ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
int sector = (i0 / 2) % sect_dims;
if (indep_sects) {
// compute theta independently for each dim sections
// (i.e. reset corresponding theta when `i0` go from one section to another)
if (sector == 0) {
theta_t = theta_base_t;
}
else if (sector == sections[0]) {
theta_h = theta_base_h;;
}
else if (sector == sec_w) {
theta_w = theta_base_w;
}
else if (sector == sec_e) {
theta_e = theta_base_e;
}
}
float theta = theta_t;
if (sector >= sections[0] && sector < sec_w) {
theta = theta_h;
}
else if (sector >= sec_w && sector < sec_w + sections[2]) {
theta = theta_w;
}
else if (sector >= sec_w + sections[2]) {
theta = theta_e;
}
rope_yarn(
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
);
cache[i0 + 1] *= sin_sign;
theta_t *= theta_scale;
theta_w *= theta_scale;
theta_h *= theta_scale;
theta_e *= theta_scale;
}
}
static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@ -9143,6 +9201,7 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@ -9156,6 +9215,7 @@ static void ggml_compute_forward_rope_f32(
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
@ -9188,6 +9248,16 @@ static void ggml_compute_forward_rope_f32(
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
@ -9203,30 +9273,44 @@ static void ggml_compute_forward_rope_f32(
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (!is_neox) {
if (is_neox || is_mrope) {
if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[1];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@ -9245,7 +9329,40 @@ static void ggml_compute_forward_rope_f32(
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
// fill the remain channels with data from src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -9256,6 +9373,7 @@ static void ggml_compute_forward_rope_f32(
}
}
}
}
}
// TODO: deduplicate f16/f32 code
@ -9269,6 +9387,7 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@ -9281,6 +9400,8 @@ static void ggml_compute_forward_rope_f16(
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
@ -9313,6 +9434,16 @@ static void ggml_compute_forward_rope_f16(
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
@ -9330,28 +9461,42 @@ static void ggml_compute_forward_rope_f16(
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (!is_neox) {
if (is_neox || is_mrope) {
if (is_vision) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[1]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
@ -9370,7 +9515,39 @@ static void ggml_compute_forward_rope_f16(
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@ -9381,6 +9558,7 @@ static void ggml_compute_forward_rope_f16(
}
}
}
}
}
static void ggml_compute_forward_rope(

View File

@ -4,6 +4,11 @@ struct rope_corr_dims {
float v[2];
};
struct mrope_sections {
int v[4];
};
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@ -108,6 +113,105 @@ static __global__ void rope_neox(
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
static __global__ void rope_multi(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= n_dims) {
const int i = row*ne0 + i0;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;
int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_ff>
static __global__ void rope_vision(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows; // i2-th tokens
int sect_dims = sections.v[0] + sections.v[1];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[i2]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[i2 + ne2]*powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
float cos_theta;
float sin_theta;
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
}
template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@ -156,6 +260,56 @@ static void rope_neox_cuda(
}
}
template<typename T>
static void rope_multi_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
} else {
rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
}
}
template<typename T>
static void rope_vision_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
if (freq_factors == nullptr) {
rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
} else {
rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
}
}
static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@ -185,6 +339,38 @@ static void rope_neox_cuda_f32(
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_multi_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_multi_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f16(
const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
static void rope_vision_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {
rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@ -201,8 +387,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
@ -210,6 +397,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
mrope_sections sections;
// RoPE alteration for extended context
float freq_base;
@ -225,8 +413,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
const int32_t * pos = (const int32_t *) src1_d;
@ -253,6 +452,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(

View File

@ -1419,8 +1419,18 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_NORM:
case GGML_OP_ROPE:
return true;
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return true;
}
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:

View File

@ -1125,8 +1125,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
case GGML_OP_ARGMAX:
case GGML_OP_NORM:
case GGML_OP_ROPE:
return true;
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return true;
}
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
@ -3026,7 +3036,9 @@ static void ggml_metal_encode_node(
} break;
case GGML_OP_ROPE:
{
GGML_ASSERT(ne10 == ne02);
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
GGML_ASSERT(ne10 >= ne02);
const int nth = MIN(1024, ne00);

View File

@ -4488,7 +4488,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return ggml_is_contiguous(op->src[0]);
}
case GGML_OP_IM2COL:
// TODO: add support for the new F32 operations
return op->src[0]->type == GGML_TYPE_F16;

View File

@ -7687,7 +7687,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return ggml_is_contiguous(op->src[0]);
}
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:

View File

@ -3517,15 +3517,18 @@ static struct ggml_tensor * ggml_rope_impl(
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
int sections[4] = {0, 0, 0, 0};
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &sections, sizeof(int)*4);
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
@ -3547,6 +3550,53 @@ struct ggml_tensor * ggml_rope(
);
}
struct ggml_tensor * ggml_rope_multi(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c,
int n_dims,
int sections[4],
int mode,
int n_ctx_orig,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
// Multimodal Rotary Position Embedding
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
if (c) {
GGML_ASSERT(c->type == GGML_TYPE_F32);
GGML_ASSERT(c->ne[0] >= n_dims / 2);
}
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(&params[11], sections, sizeof(int)*4);
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
result->src[0] = a;
result->src[1] = b;
result->src[2] = c;
return result;
}
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,