CUDA: enable Gemma FA for HIP/Pascal (llama/9581)

This commit is contained in:
Johannes Gäßler 2024-09-22 09:34:52 +02:00 committed by Georgi Gerganov
parent 008816a257
commit adf2474b10
2 changed files with 11 additions and 11 deletions

View File

@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
+        case GGML_OP_FLASH_ATTN_EXT: {
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:

View File

@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);