mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-17 22:38:07 +00:00
ggml : full ALiBi support (llama/7192)
* ggml : full ALiBi support * ggml : update ggml_soft_max_ext() CUDA, SYCL * ggml : ggml_flash_attn_ext() support ALiBi (CPU) * ggml : ggml_flash_attn_ext() support ALiBi (Metal) * ggml : fix warning * ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci * ggml : fix assert message * vulkan : add dev notes * ggml : require mask when using ALiBi ggml-ci * convert : fix convert for refact models
This commit is contained in:
@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
float max_bias;
|
||||
|
||||
#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
|
||||
memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
|
||||
|
||||
#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||
GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src2 == nullptr);
|
||||
|
||||
#pragma message("TODO: add ALiBi support")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
|
||||
GGML_ASSERT(max_bias == 0.0f);
|
||||
|
||||
ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
|
||||
} break;
|
||||
|
Reference in New Issue
Block a user