CUDA: enable Gemma FA for HIP/Pascal (llama/9581)

This commit is contained in:
Johannes Gäßler
2024-09-22 09:34:52 +02:00
committed by Georgi Gerganov
parent 008816a257
commit adf2474b10
2 changed files with 11 additions and 11 deletions

View File

@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);