Mirror of https://github.com/ggerganov/whisper.cpp.git, synced 2025-02-16 07:10:22 +00:00
CUDA: fix Volta FlashAttention logic (llama/11615)
commit dbeb7916b8
parent fad2806352
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
             break;
         // case 256:
-        //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+        //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
         //     break;
         default:
             GGML_ABORT("fatal error");
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
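The renamed check above changes which GPUs take the vector-kernel fallback: with fp16_mma_available, only devices without fp16 tensor cores appear to enter that branch, while Volta now falls through to the tensor-core dispatch further down. A minimal sketch of the intended distinction, assuming the usual ggml-cuda compute-capability values (Volta = 700, Turing = 750); the real helpers also cover AMD and other special cases and are not reproduced here:

// Hedged sketch, not the actual ggml-cuda helpers.
static bool sketch_fp16_mma_available(const int cc) {
    return cc >= 700; // Volta and newer expose fp16 tensor cores (usable via WMMA)
}
static bool sketch_new_mma_available(const int cc) {
    return cc >= 750; // the newer mma-PTX kernels need Turing or later
}

Under the old new_mma_available check, Volta (cc == 700) was routed into the fallback block; under fp16_mma_available it is not, which is why the explicit Volta branch in the next hunk matters.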
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }

     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
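Taken together, the hunks above give roughly the following dispatch for the fp16 path. This is a simplified control-flow sketch based only on the lines visible in this diff; the real ggml_cuda_flash_attn_ext also branches on precision, batch size and head size, and the sketch function name is hypothetical:

// Simplified sketch of the post-fix dispatch; not the full function.
static void flash_attn_dispatch_sketch(ggml_backend_cuda_context & ctx, ggml_tensor * dst, const int cc) {
    if (!fp16_mma_available(cc)) {
        // No fp16 tensor cores at all: use the vector kernels.
        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        return;
    }

    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
    if (cc == GGML_CUDA_CC_VOLTA) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return; // added by this commit; without it the MMA kernel below also ran on Volta
    }

    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}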