mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-24 06:46:37 +00:00
CUDA: enable Gemma FA for HIP/Pascal (llama/9581)
This commit is contained in:
parent
008816a257
commit
adf2474b10
@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
|
||||
#else
|
||||
if (op->src[0]->ne[0] == 128) {
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
if (op->src[0]->ne[0] == 128) {
|
||||
return true;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
|
||||
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
|
@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
}
|
||||
|
||||
if (!fast_fp16_available(cc)) {
|
||||
if (Q->ne[1] <= 8) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
|
Loading…
Reference in New Issue
Block a user