From 6db0e01db69f9cc1bc7d5b17f18fad3eb672eed0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 10 May 2025 22:22:48 +0200 Subject: [PATCH] CUDA: fix race conditions FlashAttention kernels (llama/13438) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 ++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 1 + 2 files changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b2f95fa3..9873ea75 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index ef0addc1..d96e3921 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { KQ[j*D + tid] = -HALF_MAX_HALF; } + __syncthreads(); half2 VKQ[ncols] = {{0.0f, 0.0f}};