mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-28 21:14:11 +00:00
CUDA: fix crash on large batch size for MoE models (llama/13384)
This commit is contained in:
parent
e27c91f6d6
commit
4b7cbb62ef
@ -10,10 +10,11 @@ static __global__ void k_get_rows(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
|
||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||
|
||||
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
|
||||
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
|
||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const int i10 = blockIdx.x;
|
||||
const int i11 = blockIdx.z / ne12;
|
||||
const int i12 = blockIdx.z % ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
@ -94,8 +96,8 @@ static void get_rows_cuda_q(
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
@ -127,8 +129,8 @@ static void get_rows_cuda_float(
|
||||
const size_t nb1, const size_t nb2, const size_t nb3,
|
||||
cudaStream_t stream) {
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
|
||||
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
|
||||
|
||||
// strides in elements
|
||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||
|
Loading…
x
Reference in New Issue
Block a user