vulkan: support multi/vision rope, and noncontiguous rope (llama/11902)

This commit is contained in:
Jeff Bolz 2025-02-16 01:52:23 -06:00 committed by Georgi Gerganov
parent fcbcad0c90
commit 8a22a8b17f
7 changed files with 204 additions and 41 deletions

View File

@ -251,6 +251,8 @@ struct vk_device_struct {
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@ -494,6 +496,10 @@ struct vk_op_rope_push_constants {
float corr_dims[2];
float theta_scale;
uint32_t has_ff;
uint32_t ne02;
uint32_t s1;
uint32_t s2;
int32_t sections[4];
};
struct vk_op_soft_max_push_constants {
@ -2180,13 +2186,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
@ -5307,6 +5319,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
{
const int mode = ((const int32_t *) dst->op_params)[2];
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@ -5315,6 +5329,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_neox_f16;
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_multi_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_multi_f16;
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_vision_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_vision_f16;
}
} else {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_norm_f32;
@ -5385,6 +5413,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_ROPE:
return true;
default:
return false;
@ -6149,7 +6178,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
// const int mode = ((int32_t *) dst->op_params)[2];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const float freq_base = ((float *) dst->op_params)[5];
@ -6158,16 +6187,24 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
const float attn_factor = ((float *) dst->op_params)[8];
const float beta_fast = ((float *) dst->op_params)[9];
const float beta_slow = ((float *) dst->op_params)[10];
int sections[4] {};
if (mode & GGML_ROPE_TYPE_MROPE) {
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
}
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr,
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
sections[0], sections[1], sections[2], sections[3],
}, dryrun);
}
@ -8264,16 +8301,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return ggml_is_contiguous(op->src[0]);
}
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
@ -8831,7 +8858,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const float attn_factor = ((float *) tensor->op_params)[8];
const float beta_fast = ((float *) tensor->op_params)[9];
const float beta_slow = ((float *) tensor->op_params)[10];
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
if (mode & GGML_ROPE_TYPE_MROPE) {
int32_t *sections = ((int32_t *) tensor->op_params) + 11;
tensor_clone = ggml_rope_multi(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
} else {
tensor_clone = ggml_rope_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
}
} else if (tensor->op == GGML_OP_UNARY) {
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_SILU:

View File

@ -25,6 +25,10 @@ layout (push_constant) uniform parameter {
float corr_dims[2];
float theta_scale;
uint has_ff;
uint ne02;
uint s1;
uint s2;
int sections[4];
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {

View File

@ -0,0 +1,60 @@
#version 450
#include "rope_head.comp"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;
if (i0 >= ne0) {
return;
}
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= p.sections[0] && sector < sec_w) {
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
}
else if (sector >= sec_w + p.sections[2]) {
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}

View File

@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}
if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@ -19,19 +22,22 @@ void main() {
return;
}
const uint i = row*p.ncols + col/2;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.n_dims/2]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims/2]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}

View File

@ -3,15 +3,18 @@
#include "rope_head.comp"
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
if (col >= p.ncols) {
if (i0 >= ne0) {
return;
}
if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
const uint row_dst = gl_GlobalInvocationID.x;
if (i0 >= p.n_dims) {
const uint i = row_dst*ne0 + i0;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
@ -19,19 +22,22 @@ void main() {
return;
}
const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const uint idst = row_dst*ne0 + i0;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + 1]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
}

View File

@ -0,0 +1,47 @@
#version 450
#include "rope_head.comp"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
uint ne0 = p.ncols;
uint ne1 = p.p_delta_rows;
uint ne2 = p.ne02;
if (i0 >= ne0) {
return;
}
const uint row_dst = gl_GlobalInvocationID.x;
const uint row_x = row_dst % ne1;
const uint channel_x = row_dst / ne1;
const uint idst = row_dst*ne0 + i0/2;
const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
const int sect_dims = p.sections[0] + p.sections[1];
const int sec_w = p.sections[1] + p.sections[0];
const uint sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < p.sections[0]) {
const uint p0 = sector;
theta_base = data_pos[channel_x]*pow(p.theta_scale, p0);
}
else if (sector >= p.sections[0] && sector < sec_w) {
const uint p0 = sector - p.sections[0];
theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0);
}
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta);
const float x0 = float(data_a[ix + 0]);
const float x1 = float(data_a[ix + p.n_dims]);
data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta);
}

View File

@ -491,6 +491,14 @@ void process_shaders() {
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));