mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-29 15:30:03 +00:00
metal : fix floating-point range of attention scores in FA kernels (llama/13090)
ggml-ci
This commit is contained in:
parent
cf3eb291ab
commit
01e1600edd
@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
{
|
{
|
||||||
float S[Q] = { [0 ... Q-1] = 0.0f };
|
float S[Q] = { [0 ... Q-1] = 0.0f };
|
||||||
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
|
||||||
|
|
||||||
// thread indices inside the simdgroup
|
// thread indices inside the simdgroup
|
||||||
// TODO: see if we can utilize quad-group functions for better performance
|
// TODO: see if we can utilize quad-group functions for better performance
|
||||||
@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
// reduce the warps sequentially
|
// reduce the warps sequentially
|
||||||
for (ushort sg = 1; sg < nsg; ++sg) {
|
for (ushort sg = 1; sg < nsg; ++sg) {
|
||||||
float S = { 0.0f };
|
float S = { 0.0f };
|
||||||
float M = { -__FLT16_MAX__/2 };
|
float M = { -__FLT_MAX__/2 };
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
|
|
||||||
{
|
{
|
||||||
float S = 0.0f;
|
float S = 0.0f;
|
||||||
float M = -__FLT16_MAX__/2;
|
float M = -__FLT_MAX__/2;
|
||||||
|
|
||||||
// thread indices inside the simdgroup
|
// thread indices inside the simdgroup
|
||||||
const short tx = tiisg%NL;
|
const short tx = tiisg%NL;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user