From 526332873b0a782c9117b15566d9a5ab625e4842 Mon Sep 17 00:00:00 2001 From: Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> Date: Tue, 9 Apr 2024 09:16:13 +0100 Subject: [PATCH] llama : add Command R Plus support (llama/6491) * Add Command R Plus GGUF * Add Command R Plus GGUF * Loading works up to LayerNorm2D * Export new tensors in 1D so they are not quantized. * Fix embedding layer based on Noeda's example * Whitespace * Add line * Fix unexpected tokens on MPS. Re-add F16 fix. ((Noeda) * dranger003: Fix block index overflow in CUDA dequantizing. * Reverted blocked multiplication code as it still has issues and could affect other Llama arches * export norms as f32 * fix overflow issues during quant and other cleanup * Type convention Co-authored-by: Georgi Gerganov * dranger003: Fix more int overflow during quant. --------- Co-authored-by: S Co-authored-by: S Co-authored-by: slaren Co-authored-by: Georgi Gerganov --- ggml-cuda.cu | 6 +- ggml-cuda/common.cuh | 2 +- ggml-cuda/convert.cu | 74 +++++----- ggml-cuda/convert.cuh | 2 +- ggml-cuda/dequantize.cuh | 10 +- ggml-cuda/dmmv.cu | 6 +- ggml-cuda/quantize.cu | 16 +- ggml-cuda/quantize.cuh | 2 +- ggml-quants.c | 310 +++++++++++++++++++-------------------- ggml-quants.h | 148 +++++++++---------- ggml.c | 16 +- ggml.h | 14 +- 12 files changed, 303 insertions(+), 303 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ce28cb55..bff8ad9d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas( // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = id == ctx.device ? ne0 : row_diff; + int64_t ldc = id == ctx.device ? ne0 : row_diff; const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b98d7cbd..481065b2 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); ////////////////////// diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 18a31edc..ed4fa274 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -4,14 +4,14 @@ #define CUDA_Q8_0_NE_ALIGN 2048 template -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { - const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x); +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { + const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x); if (i >= k) { return; } - const int ib = i/qk; // block index + const int64_t ib = i/qk; // block index const int iqs = (i%qk)/qr; // quant index const int iybs = i - i%qk; // y block start index const int y_offset = qr == 1 ? 1 : qk/2; @@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } template -static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { +static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) { #if __CUDA_ARCH__ >= CC_PASCAL constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; @@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h template static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t template static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { - const int i = blockIdx.x; + const int64_t i = blockIdx.x; // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; const int ir = tid%8; - const int ib = 8*i + ir; + const int64_t ib = 8*i + ir; if (ib >= nb32) { return; } @@ -313,14 +313,14 @@ template static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; - const int i = blockIdx.x; + const int64_t i = blockIdx.x; #if QK_K == 256 // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; + const int64_t tid = threadIdx.x; + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; dst_t * y = yy + i*QK_K + 128*ip + il; @@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t #else // assume 32 threads - const int tid = threadIdx.x; - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 + const int64_t tid = threadIdx.x; + const int64_t ip = tid/16; // 0 or 1 + const int64_t il = tid - 16*ip; // 0...15 dst_t * y = yy + i*QK_K + 16*ip + il; @@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst #endif template -static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { +static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); dequantize_block<<>>(vx, y, k); } -static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) { +static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN; if (k % CUDA_Q8_0_NE_ALIGN == 0) { const bool need_check = false; @@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * } template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb32 = k / 32; const int nb = (k + 255) / 256; dequantize_block_q4_0<<>>(vx, y, nb32); } template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb32 = k / 32; const int nb = (k + 255) / 256; dequantize_block_q4_1<<>>(vx, y, nb32); } template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_xxs<<>>(vx, y); } template -static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_xs<<>>(vx, y); } template -static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq2_s<<>>(vx, y); } template -static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq3_xxs<<>>(vx, y); } template -static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq3_s<<>>(vx, y); } template -static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq1_s<<>>(vx, y); } template -static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; dequantize_block_iq4_nl<<>>(vx, y); } template -static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_iq1_m<<>>(vx, y); } template -static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; #if QK_K == 64 dequantize_block_iq4_nl<<>>(vx, y); @@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, } template -static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; @@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res } template -static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { +static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; convert_unary<<>>(vx, y, k); } diff --git a/ggml-cuda/convert.cuh b/ggml-cuda/convert.cuh index db34c0be..5394be9f 100644 --- a/ggml-cuda/convert.cuh +++ b/ggml-cuda/convert.cuh @@ -3,7 +3,7 @@ #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 template -using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); +using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream); typedef to_t_cuda_t to_fp32_cuda_t; typedef to_t_cuda_t to_fp16_cuda_t; diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh index b5440063..bd3c2d9d 100644 --- a/ggml-cuda/dequantize.cuh +++ b/ggml-cuda/dequantize.cuh @@ -1,6 +1,6 @@ #include "common.cuh" -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; const dfloat d = x[ib].d; @@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; const dfloat d = x[ib].d; @@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; const dfloat d = x[ib].d; diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu index 0b17e3cb..7313e3e1 100644 --- a/ggml-cuda/dmmv.cu +++ b/ggml-cuda/dmmv.cu @@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; // automatic half -> float type cast if dfloat == float @@ -577,7 +577,7 @@ template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { return; @@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons for (int i = 0; i < ncols; i += iter_stride) { const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/qk; // x block index + const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index const int iqs = (col%qk)/qr; // x quant index const int iybs = col - col%qk; // y block start index diff --git a/ggml-cuda/quantize.cu b/ggml-cuda/quantize.cu index a1fbc993..7578c4b6 100644 --- a/ggml-cuda/quantize.cu +++ b/ggml-cuda/quantize.cu @@ -1,20 +1,20 @@ #include "quantize.cuh" -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { - const int ix = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) { + const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (ix >= kx_padded) { return; } - const int iy = blockDim.y*blockIdx.y + threadIdx.y; + const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y; - const int i_padded = iy*kx_padded + ix; + const int64_t i_padded = (int64_t)iy*kx_padded + ix; block_q8_1 * y = (block_q8_1 *) vy; - const int ib = i_padded / QK8_1; // block index - const int iqs = i_padded % QK8_1; // quant index + const int64_t ib = i_padded / QK8_1; // block index + const int64_t iqs = i_padded % QK8_1; // quant index const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; float amax = fabsf(xi); @@ -36,8 +36,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { - const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; +void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) { + const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ky, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); quantize_q8_1<<>>(x, vy, kx, kx_padded); diff --git a/ggml-cuda/quantize.cuh b/ggml-cuda/quantize.cuh index adb89c83..b37a4752 100644 --- a/ggml-cuda/quantize.cuh +++ b/ggml-cuda/quantize.cuh @@ -2,4 +2,4 @@ #define CUDA_QUANTIZE_BLOCK_SIZE 256 -void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream); +void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream); diff --git a/ggml-quants.c b/ggml-quants.c index f2e6c4bd..32e84434 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -544,7 +544,7 @@ static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -581,12 +581,12 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_0_reference(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -623,11 +623,11 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_1_reference(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -671,11 +671,11 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_0_reference(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -719,12 +719,12 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q5_1_reference(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -749,7 +749,7 @@ void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -938,7 +938,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -973,7 +973,7 @@ void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1192,7 +1192,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { #endif } -void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -1212,7 +1212,7 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int } } -void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_1; assert(k % qk == 0); @@ -1233,7 +1233,7 @@ void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int } } -void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -1259,7 +1259,7 @@ void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int } } -void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK5_1; assert(k % qk == 0); @@ -1286,7 +1286,7 @@ void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int } } -void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK8_0; assert(k % qk == 0); @@ -1581,7 +1581,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1658,7 +1658,7 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1704,7 +1704,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q2_K_reference(x, vy, k); } @@ -1960,14 +1960,14 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri } } -size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, nrow*n_per_row); + quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -1978,7 +1978,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int nrow, //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2092,7 +2092,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } #if QK_K == 256 -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2142,7 +2142,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } #else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); assert(QK_K == 64); const int nb = k / QK_K; @@ -2175,11 +2175,11 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } #endif -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q3_K_reference(x, vy, k); } -static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int n_per_row, const float * restrict quant_weights) { +static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q3_K_reference(x, y, n_per_row); @@ -2268,14 +2268,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri #endif } -size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, nrow*n_per_row); + quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2286,7 +2286,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2393,7 +2393,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2432,19 +2432,19 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; quantize_row_q4_K_reference(x, y, k); } -static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q4_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; uint8_t L[QK_K]; uint8_t Laux[32]; @@ -2516,14 +2516,14 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri #endif } -size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, nrow*n_per_row); + quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2534,9 +2534,9 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; #if QK_K == 256 uint8_t L[QK_K]; @@ -2676,9 +2676,9 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict } } -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2721,19 +2721,19 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; quantize_row_q5_K_reference(x, y, k); } -static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q5_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; uint8_t L[QK_K]; uint8_t Laux[32]; @@ -2825,14 +2825,14 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri #endif } -size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, nrow*n_per_row); + quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -2843,9 +2843,9 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int nrow, // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; int8_t L[QK_K]; float scales[QK_K/16]; @@ -2925,9 +2925,9 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict } } -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -2972,19 +2972,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; quantize_row_q6_K_reference(x, y, k); } -static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { #if QK_K != 256 (void)quant_weights; quantize_row_q6_K_reference(x, y, n_per_row); #else assert(n_per_row % QK_K == 0); - const int nb = n_per_row / QK_K; + const int64_t nb = n_per_row / QK_K; int8_t L[QK_K]; float scales[QK_K/16]; @@ -3067,14 +3067,14 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri #endif } -size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, nrow*n_per_row); + quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3083,7 +3083,7 @@ size_t quantize_q6_K(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { @@ -3098,7 +3098,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK4_0; + const int64_t nb = n_per_row/QK4_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_0 * ib; const float * qw = quant_weights + QK4_0 * ib; @@ -3111,14 +3111,14 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri } } -size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, nrow*n_per_row); + quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3126,7 +3126,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { @@ -3141,7 +3141,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK4_1; + const int64_t nb = n_per_row/QK4_1; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK4_1 * ib; const float * qw = quant_weights + QK4_1 * ib; @@ -3156,14 +3156,14 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri } } -size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, nrow*n_per_row); + quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3171,7 +3171,7 @@ size_t quantize_q4_1(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { @@ -3186,7 +3186,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK5_0; + const int64_t nb = n_per_row/QK5_0; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK5_0 * ib; const float * qw = quant_weights + QK5_0 * ib; @@ -3210,14 +3210,14 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri } } -size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, nrow*n_per_row); + quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3225,7 +3225,7 @@ size_t quantize_q5_0(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int n_per_row, const float * quant_weights) { +static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { @@ -3240,7 +3240,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; float sigma2 = sum_x2/n_per_row; - const int nb = n_per_row/QK5_1; + const int64_t nb = n_per_row/QK5_1; for (int ib = 0; ib < nb; ++ib) { const float * xb = x + QK5_1 * ib; const float * qw = quant_weights + QK5_1 * ib; @@ -3263,14 +3263,14 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri } } -size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, nrow*n_per_row); + quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights); src += n_per_row; qrow += row_size; @@ -3278,18 +3278,18 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int nrow, return nrow * row_size; } -size_t quantize_q8_0(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, nrow*n_per_row); + quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } // ====================== "True" 2-bit (de)-quantization -void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; uint32_t aux32[2]; const uint8_t * aux8 = (const uint8_t *)aux32; @@ -3315,9 +3315,9 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y // ====================== 2.3125 bpw (de)-quantization -void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float db[2]; @@ -3342,9 +3342,9 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, // ====================== 2.5625 bpw (de)-quantization -void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float db[2]; @@ -3374,9 +3374,9 @@ void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, in // ====================== 3.0625 bpw (de)-quantization -void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int k) { +void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; uint32_t aux32; @@ -3406,9 +3406,9 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y // ====================== 3.3125 bpw (de)-quantization -void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3449,9 +3449,9 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in // ====================== 1.5625 bpw (de)-quantization -void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) { +void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3474,9 +3474,9 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in } } -void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int k) { +void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; float delta[4]; uint16_t idx[4]; @@ -3535,9 +3535,9 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; -void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) { +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) { assert(k % QK4_NL == 0); - const int nb = k / QK4_NL; + const int64_t nb = k / QK4_NL; for (int i = 0; i < nb; i++) { @@ -3553,12 +3553,12 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, } } -void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int k) { +void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); #if QK_K == 64 dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k); #else - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3582,9 +3582,9 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K ============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3621,9 +3621,9 @@ void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict } } -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); - const int nb = k / QK_K; + const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { for (int j = 0; j < QK_K; ++j) { @@ -3632,7 +3632,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int } } -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q8_K_reference(x, y, k); } @@ -10648,7 +10648,7 @@ static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS); @@ -10664,7 +10664,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_xxs * y = vy; @@ -10821,7 +10821,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict } } -static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS); @@ -10837,7 +10837,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_xs * y = vy; @@ -11001,11 +11001,11 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v } } -size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_xxs); @@ -11013,11 +11013,11 @@ size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int nro return nrow * nblock * sizeof(block_iq2_xxs); } -size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_xs); @@ -11242,7 +11242,7 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n, +static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq3_data_index(grid_size); @@ -11259,7 +11259,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v const int kMaxQ = 8; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; ggml_fp16_t * dh; uint8_t * qs; @@ -11455,11 +11455,11 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v } } -size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq3_xxs); @@ -11467,13 +11467,13 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int nro return nrow * nblock * sizeof(block_iq3_xxs); } -void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; quantize_row_iq3_xxs_reference(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) { +void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -11504,7 +11504,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo const int kMaxQ = 8; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq3_s * y = vy; @@ -11661,9 +11661,9 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo } #define IQ3S_BLOCK_SIZE 32 -size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; float scales[QK_K/IQ3S_BLOCK_SIZE]; float weight[IQ3S_BLOCK_SIZE]; float xval[IQ3S_BLOCK_SIZE]; @@ -11674,7 +11674,7 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, bool is_on_grid_aux[IQ3S_BLOCK_SIZE/4]; uint8_t block_signs[IQ3S_BLOCK_SIZE/8]; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights, scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs); src += n_per_row; @@ -11683,13 +11683,13 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq3_s); } -void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; quantize_row_iq3_s_reference(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int k) { +void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -11822,7 +11822,7 @@ static int iq1_sort_helper(const void * left, const void * right) { #define IQ1S_BLOCK_SIZE 32 #define IQ1M_BLOCK_SIZE 16 -static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights, +static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, float * sumx, @@ -11846,7 +11846,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy block_iq1_s * y = vy; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; const int block_size = IQ1S_BLOCK_SIZE; @@ -11980,7 +11980,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); float scales[QK_K/IQ1S_BLOCK_SIZE]; float weight[IQ1S_BLOCK_SIZE]; @@ -11990,9 +11990,9 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, float pairs[2*IQ1S_BLOCK_SIZE]; uint16_t index[IQ1S_BLOCK_SIZE/8]; int8_t shifts[QK_K/IQ1S_BLOCK_SIZE]; - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts); src += n_per_row; qrow += nblock*sizeof(block_iq1_s); @@ -12000,7 +12000,7 @@ size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq1_s); } -static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights, +static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights, float * scales, float * weight, float * pairs, @@ -12022,7 +12022,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy block_iq1_m * y = vy; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; const int block_size = IQ1M_BLOCK_SIZE; @@ -12265,7 +12265,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); float scales[QK_K/IQ1M_BLOCK_SIZE]; float weight[IQ1M_BLOCK_SIZE]; @@ -12273,9 +12273,9 @@ size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int nrow, float pairs[2*IQ1M_BLOCK_SIZE]; uint16_t index[IQ1M_BLOCK_SIZE/8]; int8_t shifts[QK_K/IQ1M_BLOCK_SIZE]; - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts); src += n_per_row; qrow += nblock*sizeof(block_iq1_m); @@ -12407,16 +12407,16 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block } } -size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK4_NL == 0); - int nblock = n_per_row/QK4_NL; + int64_t nblock = n_per_row/QK4_NL; char * qrow = (char *)dst; uint8_t L[QK4_NL]; float weight[QK4_NL]; uint16_t unused_h; uint8_t * unused_l = NULL; float scale; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { block_iq4_nl * iq4 = (block_iq4_nl *)qrow; for (int ibl = 0; ibl < nblock; ++ibl) { const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; @@ -12429,9 +12429,9 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow return nrow * nblock * sizeof(block_iq4_nl); } -void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { GGML_ASSERT(k%QK4_NL == 0); - int nblock = k/QK4_NL; + int64_t nblock = k/QK4_NL; uint8_t L[QK4_NL]; float weight[QK4_NL]; uint16_t unused_h; @@ -12444,22 +12444,22 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { } } -void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) { +void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } -size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { #if QK_K == 64 return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights); #else GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; uint8_t L[QK_K]; float weight[32]; float scales[QK_K/32]; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { block_iq4_xs * iq4 = (block_iq4_xs *)qrow; for (int ibl = 0; ibl < nblock; ++ibl) { const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL; @@ -12473,20 +12473,20 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow #endif } -void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; quantize_row_iq4_xs_reference(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int k) { +void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } // =============================== 2.5625 bpw -static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { +static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) { const int gindex = iq2_data_index(GGML_TYPE_IQ2_S); @@ -12501,7 +12501,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy const int kMaxQ = 3; - const int nbl = n/QK_K; + const int64_t nbl = n/QK_K; block_iq2_s * y = vy; @@ -12654,11 +12654,11 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy } } -size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) { +size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); - int nblock = n_per_row/QK_K; + int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; - for (int row = 0; row < nrow; ++row) { + for (int64_t row = 0; row < nrow; ++row) { quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq2_s); @@ -12666,12 +12666,12 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int nrow, return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int k) { +void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } -void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int k) { +void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; quantize_row_iq2_s_reference(x, y, k); diff --git a/ggml-quants.h b/ggml-quants.h index ac1091c3..4d436a8f 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -12,70 +12,70 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int k); -void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int k); -void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int k); -void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int k); -void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int k); -void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int k); +void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k); -void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int k); -void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int k); -void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int k); -void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k); -void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); +void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); -void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); -void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int k); -void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int k); -void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int k); +void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); -void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); -void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -101,26 +101,26 @@ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); +size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); -size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int nrows, int n_per_row, const float * imatrix); +size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); void iq2xs_init_impl(enum ggml_type type); void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml.c b/ggml.c index c9b0a6a0..793b67f4 100644 --- a/ggml.c +++ b/ggml.c @@ -338,14 +338,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { return GGML_FP32_TO_FP16(x); } -void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) { - for (int i = 0; i < n; i++) { +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { + for (int64_t i = 0; i < n; i++) { y[i] = GGML_FP16_TO_FP32(x[i]); } } -void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { - int i = 0; +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { + int64_t i = 0; #if defined(__F16C__) for (; i + 7 < n; i += 8) { __m256 x_vec = _mm256_loadu_ps(x + i); @@ -20331,11 +20331,11 @@ size_t ggml_quantize_chunk( enum ggml_type type, const float * src, void * dst, - int start, - int nrows, - int n_per_row, + int64_t start, + int64_t nrows, + int64_t n_per_row, const float * imatrix) { - const int n = nrows * n_per_row; + const int64_t n = (int64_t) nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); diff --git a/ggml.h b/ggml.h index 5cef45c0..abe3767f 100644 --- a/ggml.h +++ b/ggml.h @@ -332,8 +332,8 @@ extern "C" { GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); - GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); - GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n); struct ggml_object; struct ggml_context; @@ -2210,9 +2210,9 @@ extern "C" { enum ggml_type type, const float * src, void * dst, - int start, - int nrows, - int n_per_row, + int64_t start, + int64_t nrows, + int64_t n_per_row, const float * imatrix); // @@ -2377,8 +2377,8 @@ extern "C" { #else #define GGML_RESTRICT restrict #endif - typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); - typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, const void * GGML_RESTRICT y, size_t by, int nrc);