mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-09 20:13:14 +00:00
CUDA: fix race condition in MMQ ids_dst (llama/13294)
This commit is contained in:
parent
22ba2e27ce
commit
7564f5e6f1
@ -2636,6 +2636,7 @@ static __global__ void mul_mat_q(
|
|||||||
|
|
||||||
ids_dst_shared[j] = j;
|
ids_dst_shared[j] = j;
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
|
||||||
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
|
||||||
@ -2664,6 +2665,7 @@ static __global__ void mul_mat_q(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// __syncthreads(); // There is no previous tile that could cause a race condition.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
||||||
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
||||||
@ -2674,6 +2676,7 @@ static __global__ void mul_mat_q(
|
|||||||
|
|
||||||
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
||||||
@ -2740,6 +2743,7 @@ static __global__ void mul_mat_q(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
||||||
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
||||||
@ -2750,6 +2754,7 @@ static __global__ void mul_mat_q(
|
|||||||
|
|
||||||
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
||||||
@ -2805,6 +2810,7 @@ static __global__ void mul_mat_q(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
|
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
|
||||||
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
|
||||||
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
|
||||||
@ -2815,6 +2821,7 @@ static __global__ void mul_mat_q(
|
|||||||
|
|
||||||
ids_dst_shared[j] = j;
|
ids_dst_shared[j] = j;
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
|
||||||
|
Loading…
x
Reference in New Issue
Block a user