CUDA: fix FA out-of-bounds writes (llama/7465)

This commit is contained in:
Johannes Gäßler 2024-05-22 17:58:25 +02:00 committed by Georgi Gerganov
parent 5848dfd9c8
commit 3c63f4cf35
2 changed files with 8 additions and 0 deletions

View File

@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum(kqsum_j);

View File

@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);