kompute : improve backend to pass test_backend_ops (llama/10542)

* kompute: op_unary: reject unsupported parameters

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: softmax: implement ALiBi support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: rope: implement neox and phi3 support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_q4_k permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_[q4_0|q4_1|q8_0] permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_f16 permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

* kompute: op_mul_mat_q6_k permutted support

Signed-off-by: Sergio Lopez <slp@redhat.com>

---------

Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
Sergio López 2024-11-28 12:51:38 +01:00 committed by Georgi Gerganov
parent 90dd5fca9c
commit 42099a9342
16 changed files with 403 additions and 233 deletions

View File

@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp
kompute-shaders/op_getrows_q6_k.comp
kompute-shaders/op_rope_f16.comp
kompute-shaders/op_rope_f32.comp
kompute-shaders/op_rope_norm_f16.comp
kompute-shaders/op_rope_norm_f32.comp
kompute-shaders/op_rope_neox_f16.comp
kompute-shaders/op_rope_neox_f32.comp
kompute-shaders/op_cpy_f16_f16.comp
kompute-shaders/op_cpy_f16_f32.comp
kompute-shaders/op_cpy_f32_f16.comp
@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h
shaderop_getrows_q6_k.h
shaderop_rope_f16.h
shaderop_rope_f32.h
shaderop_rope_norm_f16.h
shaderop_rope_norm_f32.h
shaderop_rope_neox_f16.h
shaderop_rope_neox_f32.h
shaderop_cpy_f16_f16.h
shaderop_cpy_f16_f32.h
shaderop_cpy_f32_f16.h

View File

@ -28,8 +28,10 @@
#include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h"
#include "shaderop_getrows_q6_k.h"
#include "shaderop_rope_f16.h"
#include "shaderop_rope_f32.h"
#include "shaderop_rope_norm_f16.h"
#include "shaderop_rope_norm_f32.h"
#include "shaderop_rope_neox_f16.h"
#include "shaderop_rope_neox_f32.h"
#include "shaderop_cpy_f16_f16.h"
#include "shaderop_cpy_f16_f32.h"
#include "shaderop_cpy_f32_f16.h"
@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
vk::DescriptorPoolSize(
vk::DescriptorType::eStorageBuffer,
3 * size // Descriptor count is number of possible tensors to pass into an algorithm
4 * size // Descriptor count is number of possible tensors to pass into an algorithm
)
};
@ -788,7 +790,8 @@ static void ggml_vk_soft_max(
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
float scale
float scale, float max_bias, float m0, float m1,
uint32_t n_head_log2
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
kp::shader_data::op_softmax_comp_spv_len);
@ -796,12 +799,14 @@ static void ggml_vk_soft_max(
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne01, ne02;
float scale;
float scale, max_bias, m0, m1;
uint32_t n_head_log2;
int32_t mask;
} pushConsts {
safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne01, ne02,
scale,
scale, max_bias, m0, m1,
n_head_log2,
bool(inB)
};
@ -911,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02,
uint32_t nb00, uint32_t nb01, uint32_t nb02,
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
uint32_t nb10, uint32_t nb11, uint32_t nb12,
uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
int32_t ne0, int32_t ne1,
uint32_t r2, uint32_t r3
) {
@ -923,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne01, ne02;
uint32_t nb00, nb01, nb02;
uint32_t nb00, nb01, nb02, nb03;
int32_t ne10, ne11, ne12;
uint32_t nb10, nb11, nb12;
uint32_t nb10, nb11, nb12, nb13;
int32_t ne0, ne1;
uint32_t r2, r3;
} pushConsts {
safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne01, ne02,
nb00, nb01, nb02,
nb00, nb01, nb02, nb03,
ne10, ne11, ne12,
nb10, nb11, nb12,
nb10, nb11, nb12, nb13,
ne0, ne1,
r2, r3
};
@ -1013,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
int32_t ne00, int32_t ne01, int32_t ne02,
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
int32_t ne0, int32_t ne1,
uint32_t nb01, uint32_t nb02, uint32_t nb03,
uint32_t nb11, uint32_t nb12, uint32_t nb13,
uint32_t r2, uint32_t r3
) {
struct PushConstants {
@ -1020,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
int32_t ne00, ne01, ne02;
int32_t ne10, ne12;
int32_t ne0, ne1;
uint32_t nb01, nb02, nb03;
uint32_t nb11, nb12, nb13;
uint32_t r2, r3;
} pushConsts {
safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne01, ne02,
ne10, ne12,
ne0, ne1,
nb01, nb02, nb03,
nb11, nb12, nb13,
r2, r3
};
auto name = std::string(__func__) + "_" + suffix;
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(name)) {
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(name);
@ -1074,19 +1085,26 @@ static void ggml_vk_mul_mat_q4_k(
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
int32_t ne1, int32_t r2, int32_t r3
int32_t ne00, int32_t ne01, int32_t ne02,
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
int32_t ne0, int32_t ne1,
uint32_t nb01, uint32_t nb02, uint32_t nb03,
uint32_t nb11, uint32_t nb12, uint32_t nb13,
uint32_t r2, uint32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
uint32_t r2, r3;
} pushConsts {
0, 0, 0,
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
nb01, nb02, nb03, nb11, nb12, nb13,
r2, r3
};
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
@ -1108,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
int32_t ne00, int32_t ne01, int32_t ne02,
int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
int32_t ne0, int32_t ne1,
uint32_t nb01, uint32_t nb02, uint32_t nb03,
uint32_t nb11, uint32_t nb12, uint32_t nb13,
uint32_t r2, uint32_t r3
) {
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t ne00, ne10, ne0, ne1, ne01, gqa;
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
uint32_t r2, r3;
} pushConsts {
inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
ne00, ne10, ne0, ne1, ne01, ne12/ne02
ne00, ne10, ne0, ne1, ne01, ne02, ne12,
nb01, nb02, nb03, nb11, nb12, nb13,
r2, r3
};
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) {
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
const uint32_t local_x = 2;
const uint32_t local_y = ggml_vk_current_device().subgroupSize;
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
} else {
s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
}
@ -1217,10 +1244,11 @@ static void ggml_vk_rope(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& inC,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
int32_t ne01, int32_t ne02, int32_t ne03,
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
int32_t ne0,
@ -1228,11 +1256,17 @@ static void ggml_vk_rope(
) {
GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
static const auto spirv_f16 = getSpirvShader(
kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
static const auto spirv_norm_f16 = getSpirvShader(
kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
);
static const auto spirv_f32 = getSpirvShader(
kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
static const auto spirv_norm_f32 = getSpirvShader(
kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
);
static const auto spirv_neox_f16 = getSpirvShader(
kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
);
static const auto spirv_neox_f32 = getSpirvShader(
kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
);
int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@ -1247,32 +1281,40 @@ static void ggml_vk_rope(
GGML_ASSERT(nb0 % type_size == 0);
struct PushConstants {
uint32_t inAOff, inBOff, outOff;
uint32_t inAOff, inBOff, inCOff, outOff;
int32_t n_dims, mode, n_ctx_orig;
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
float freq_base, freq_scale;
bool has_freq_factors;
float ext_factor, attn_factor, beta_fast, beta_slow;
uint32_t nb00, nb01, nb02, nb03;
int32_t ne0;
uint32_t nb0, nb1, nb2, nb3;
} pushConsts {
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
n_dims, mode, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
freq_base, freq_scale,
has_freq_factors,
ext_factor, attn_factor, beta_fast, beta_slow,
nb00, nb01, nb02, nb03,
ne0,
nb0, nb1, nb2, nb3
};
auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
auto & inC_ = inC ? inC : inA;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_f16 = src0t == GGML_TYPE_F16;
auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(name)) {
auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
s_algo = komputeManager()->algorithm<float, PushConstants>(
name, s_kompute_context->pool.get(), {inA, inB, out},
src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
{unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
);
} else {
s_algo = komputeManager()->getAlgorithm(name);
s_algo->setTensors({inA, inB, out});
s_algo->setTensors({inA, inB, inC_, out});
s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
s_algo->setPushConstants<PushConstants>({pushConsts});
s_algo->updateDescriptors(s_kompute_context->pool.get());
@ -1351,11 +1393,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
}
static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
int64_t n = ggml_nelements(op);
switch (op->op) {
case GGML_OP_UNARY:
if (n % 4 != 0) return false;
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU:
if (n % 8 != 0) return false;
// fall through
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SILU:
return ggml_is_contiguous(op->src[0]);
default:
@ -1413,8 +1459,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_Q6_K:
return op->ne[3] == 1;
case GGML_TYPE_Q6_K:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
@ -1515,9 +1561,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
uint32_t off_src0 = 0;
uint32_t off_src1 = 0;
uint32_t off_src2 = 0;
uint32_t off_dst = 0;
const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
const std::shared_ptr<kp::Tensor>& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor;
switch (dst->op) {
@ -1593,11 +1641,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
#pragma message("TODO: add ALiBi support")
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
GGML_ASSERT(max_bias == 0.0f);
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
} break;
case GGML_OP_DIAG_MASK_INF:
{
@ -1649,38 +1702,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
case GGML_TYPE_F16:
ggml_vk_mul_mat_f16(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ne0, ne1, r2, r3
);
break;
case GGML_TYPE_Q8_0:
ggml_vk_mul_mat_q8_0(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
);
break;
case GGML_TYPE_Q4_0:
ggml_vk_mul_mat_q4_0(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
);
break;
case GGML_TYPE_Q4_1:
ggml_vk_mul_mat_q4_1(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
);
break;
case GGML_TYPE_Q4_K:
ggml_vk_mul_mat_q4_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
);
break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
);
break;
default: {
@ -1709,13 +1768,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
} break;
case GGML_OP_ROPE:
{
#pragma message("TODO: implement phi3 frequency factors support")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];
@ -1724,6 +1776,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const bool has_freq_factors = dst->src[2] != nullptr;
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@ -1732,8 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
ggml_vk_rope(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
);
} break;

View File

@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable

View File

@ -20,12 +20,14 @@ layout (push_constant) uniform parameter {
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne10;
int ne11;
int ne12;
uint nb10;
uint nb11;
uint nb12;
uint nb13;
int ne0;
int ne1;
uint r2;
@ -42,7 +44,7 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
@ -52,7 +54,7 @@ void main() {
break;
}
const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {

View File

@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
int ne01;
int ne02;
int ne12;
int r2;
int r3;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
@ -50,10 +56,11 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;
const uint xblk = ib_row + offset0 + pcs.inAOff;
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
const uint xblk = offset0 + pcs.inAOff;
const uint y = (offset1 / 4) + pcs.inBOff;
float yl[16];
float yh[16];
@ -74,7 +81,7 @@ void main() {
}
for (int row = 0; row < N_DST; row++) {
uint row_idx = row * nb;
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);

View File

@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
int ne0;
int ne1;
int ne01;
int gqa;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
@ -34,12 +43,15 @@ void main() {
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint r2 = gl_WorkGroupID.z;
const uint im = gl_WorkGroupID.z;
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
const uint x = row * nb + offset0; // Based from inA without base offset
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
@ -89,6 +101,6 @@ void main() {
const float tot = subgroupAdd(sumf);
if (subgroupElect()) {
out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
}
}

View File

@ -14,10 +14,15 @@ void main() {
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
// pointers to src0 rows
uint ax[N_ROWS];
for (int row = 0; row < N_ROWS; ++row) {
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint x = offset0; // Based from inA without base offset
const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
ax[row] = offset0 + pcs.inAOff;
}
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
@ -32,8 +37,7 @@ void main() {
for (uint ib = ix; ib < nb; ib += 16) {
for (int row = 0; row < N_ROWS; row++) {
const uint block_index = x + ib + row * nb;
sumf[row] += block_q_n_dot_y(block_index, yb, il);
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
}
yb += BLOCKS_IN_QUANT * 16;

View File

@ -1,5 +1,5 @@
layout(local_size_x_id = 0) in;
layout(local_size_y = 1) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
int ne12;
int ne0;
int ne1;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;

View File

@ -1,73 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
const int p = inB[pcs.inBOff + i2];
float theta = float(p);
if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
}
} else {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const uint cur_rot = ic;
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint i0 = ic/2;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
}
for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
const uint i0 = ic;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data + 0] = inA[src + 0];
out_[dst_data + 1] = inA[src + 1];
}
}
}

View File

@ -1,73 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
const int p = inB[pcs.inBOff + i2];
float theta = float(p);
if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
}
} else {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const uint cur_rot = ic;
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
theta *= theta_scale;
const uint i0 = ic/2;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
}
for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
const uint i0 = ic;
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data + 0] = inA[src + 0];
out_[dst_data + 1] = inA[src + 1];
}
}
}

View File

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@ -0,0 +1,52 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}

View File

@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
int ne01;
int ne02;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
int mask;
} pcs;
@ -34,17 +38,29 @@ void main() {
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
const uint pdst = extra_off + pcs.outOff; // Based from out_
float slope = 1.0f;
// ALiBi
if (pcs.max_bias > 0.0f) {
int64_t h = i02;
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
slope = pow(base, float(exp));
}
// parallel max
float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
}
float max_ = subgroupMax(localMax);
// parallel sum
float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0;
}

View File

@ -8,12 +8,14 @@ layout(local_size_x = 1) in;
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint inCOff;
uint outOff;
int n_dims;
int mode;
int n_ctx_orig;
float freq_base;
float freq_scale;
bool has_freq_factors;
float ext_factor;
float attn_factor;
float beta_fast;