mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-25 03:24:13 +00:00
CUDA: fix race conditions FlashAttention kernels (llama/13438)
This commit is contained in:
parent
16f3546f38
commit
6db0e01db6
@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// Write back combined meta data:
|
// Write back combined meta data:
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
||||||
|
@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
KQ[j*D + tid] = -HALF_MAX_HALF;
|
KQ[j*D + tid] = -HALF_MAX_HALF;
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user