From 3c63f4cf35a26eede87fb60338c8e0aa23d6d8f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 22 May 2024 17:58:25 +0200 Subject: [PATCH] CUDA: fix FA out-of-bounds writes (llama/7465) --- ggml-cuda/fattn-tile-f16.cu | 4 ++++ ggml-cuda/fattn-tile-f32.cu | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index 4a07ac6a..586d469c 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum(kqsum_j); diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index b8b2f69e..b6ef8eb4 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j);