CPU/CUDA: Gemma 2 FlashAttention support (llama/8542)

* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check
Author: Johannes Gäßler
Date: 2024-08-24 21:34:59 +02:00
Committed by: Georgi Gerganov
Parent: 9b16ddd3a5
Commit: 24d8534bd8
10 changed files with 304 additions and 63 deletions
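For context, the logit softcapping referenced in the commit message (used by Gemma 2) squashes each raw attention score s into the range (-cap, +cap) via s = cap * tanh(s / cap) before the mask and softmax are applied. A minimal standalone sketch of that formula, not taken from the commit (the function name is illustrative):

#include <math.h>

// Squash a raw attention score into (-cap, +cap).
// A cap of 0.0f means softcapping is disabled, matching the kernel's check further below.
float softcap_score(float s, float cap) {
    return cap != 0.0f ? cap * tanhf(s / cap) : s;
}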

@@ -7250,7 +7250,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor * v,
         struct ggml_tensor * mask,
         float scale,
-        float max_bias) {
+        float max_bias,
+        float logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
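With this change ggml_flash_attn_ext() takes the softcap as an extra trailing parameter. A hedged sketch of a call site, assuming ctx, q, k, v, and mask are previously built ggml tensors and n_embd_head is the per-head dimension (these names and values are illustrative, not from the commit):

// Fragment of a hypothetical graph-building function.
const float scale         = 1.0f/sqrtf((float) n_embd_head); // usual 1/sqrt(head_dim) scaling
const float max_bias      = 0.0f;                            // no ALiBi
const float logit_softcap = 50.0f;                           // e.g. Gemma 2's attention softcap; 0.0f disables it

struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);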
@@ -7277,7 +7278,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7297,7 +7298,7 @@ void ggml_flash_attn_ext_set_prec(
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_attn_back
@@ -15745,11 +15746,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
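Dividing scale by logit_softcap up front is what the "apply logit_softcap to scale in kernel" bullet refers to: the inner loop later computes s = s*scale followed by logit_softcap*tanhf(s), which equals logit_softcap*tanhf((Q·K)*scale_orig/logit_softcap), i.e. softcapping of the conventionally scaled score. A small self-contained check of that identity (values are illustrative, not from the commit):

#include <assert.h>
#include <math.h>
#include <stdio.h>

int main(void) {
    const float cap   = 50.0f;   // illustrative softcap
    const float scale = 0.125f;  // illustrative 1/sqrt(head_dim)
    const float kq    = 37.0f;   // some raw Q*K dot product

    // Reference: softcap applied to the conventionally scaled score.
    const float ref = cap * tanhf((kq * scale) / cap);

    // Kernel's formulation: fold the cap into the scale once, then cap*tanh(kq*scale').
    const float got = cap * tanhf(kq * (scale / cap));

    assert(fabsf(ref - got) < 1e-6f);
    printf("ref = %f, got = %f\n", ref, got);
    return 0;
}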
@@ -15813,7 +15820,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                s = s*scale + mv; // scale KQ value and apply mask
+                s = s*scale; // scale KQ value
+
+                if (logit_softcap != 0.0f) {
+                    s = logit_softcap*tanhf(s);
+                }
+
+                s += mv; // apply mask
 
                 const float Mold = M;
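Note the order per score: scale first (with the cap already folded in), then the optional tanh squashing, and only then the mask term mv, so the cap bounds the Q·K contribution but not the ALiBi/mask bias. A compact restatement of that order as a standalone helper (a sketch, not code from the commit; names are illustrative):

#include <math.h>

// One attention score, following the order used in the scalar CPU path above:
// scale (assumed already divided by the cap), optional softcap, then mask bias.
float attn_score(float kq, float scale, float logit_softcap, float mask_bias) {
    float s = kq * scale;              // scale KQ value
    if (logit_softcap != 0.0f) {
        s = logit_softcap * tanhf(s);  // squash into (-cap, +cap)
    }
    return s + mask_bias;              // apply mask
}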
@@ -15822,7 +15835,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-                if (v->type== GGML_TYPE_F16) {
+                if (v->type == GGML_TYPE_F16) {
                     if (s > M) {
                         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
@@ -15889,7 +15902,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {