metal : speed-up KQ multiplication

This commit is contained in:
Georgi Gerganov 2023-09-13 19:59:16 +03:00
parent 16db4da3f1
commit 8e8daa8451
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 4 additions and 5 deletions

View File

@@ -856,8 +856,7 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
+    if (ggml_is_contiguous(src1) &&
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&

View File

@@ -1688,7 +1688,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
0, 2, 1, 3);
// K * Q
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
@@ -2089,7 +2089,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
// K * Q
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
@@ -2184,7 +2184,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
0, 2, 1, 3);
// K * Q
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
+        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, ggml_cont(ctx0, Q));
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,