From f9015b585b41316f1eaed1668cf522915254cc01 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Sun, 30 Mar 2025 16:59:38 +0800 Subject: [PATCH] musa: fix all warnings, re-enable `-DLLAMA_FATAL_WARNINGS=ON` in ci and update doc (llama/12611) * musa: fix all warnings Signed-off-by: Xiaodong Ye * musa: enable -DLLAMA_FATAL_WARNINGS=ON in run.sh Signed-off-by: Xiaodong Ye * musa: update ci doc (install ccache) Signed-off-by: Xiaodong Ye * fix Windows build issue Signed-off-by: Xiaodong Ye * Address review comments Signed-off-by: Xiaodong Ye * Address review comments Signed-off-by: Xiaodong Ye --------- Signed-off-by: Xiaodong Ye --- ggml/src/ggml-common.h | 18 ++++-- ggml/src/ggml-cuda/common.cuh | 4 ++ ggml/src/ggml-cuda/concat.cu | 4 +- ggml/src/ggml-cuda/conv-transpose-1d.cu | 6 +- ggml/src/ggml-cuda/convert.cu | 2 +- ggml/src/ggml-cuda/fattn-common.cuh | 9 +-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 85 ++++++++++++++++--------- ggml/src/ggml-cuda/fattn-tile-f16.cu | 14 +++- ggml/src/ggml-cuda/fattn-tile-f32.cu | 12 ++++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 14 +++- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 10 +++ ggml/src/ggml-cuda/fattn-wmma-f16.cu | 12 +++- ggml/src/ggml-cuda/mma.cuh | 2 + ggml/src/ggml-cuda/mmq.cuh | 60 ++++++++++------- ggml/src/ggml-cuda/mmv.cu | 2 +- ggml/src/ggml-cuda/mmvq.cu | 6 +- ggml/src/ggml-cuda/pad.cu | 2 +- ggml/src/ggml-cuda/upscale.cu | 2 +- 18 files changed, 189 insertions(+), 75 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6c02b69e..086c822d 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP +#ifdef _MSC_VER +#define GGML_EXTENSION +#else // _MSC_VER +#define GGML_EXTENSION __extension__ +#endif // _MSC_VER + #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b #define QK4_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half m; // min @@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 #define QK5_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half m; // min @@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block #define QK8_1 32 typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // delta ggml_half s; // d * sum(qs[i]) @@ -250,7 +256,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins @@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12 // weight is represented as x = a * q + b // Effectively 4.5 bits per weight typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins @@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, // weight is represented as x = a * q + b // Effectively 5.5 bits per weight typedef struct { - union { + GGML_EXTENSION union { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index f8c55a2b..a718b6a1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -288,6 +288,10 @@ static __device__ void no_device_code( __trap(); GGML_UNUSED(no_device_code); // suppress unused function warning + +#if defined(GGML_USE_MUSA) + __builtin_unreachable(); +#endif // defined(GGML_USE_MUSA) } #ifdef __CUDA_ARCH__ diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index aafbaf80..e9ffd274 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float * blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (blockIdx.y < ne01) { // src0 + if (blockIdx.y < (unsigned)ne01) { // src0 int offset_src = nidx + blockIdx.y * ne0 + @@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float * blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (blockIdx.z < ne02) { // src0 + if (blockIdx.z < (unsigned)ne02) { // src0 int offset_src = nidx + blockIdx.y * ne0 + diff --git a/ggml/src/ggml-cuda/conv-transpose-1d.cu b/ggml/src/ggml-cuda/conv-transpose-1d.cu index b1e94d6f..fe4caf67 100644 --- a/ggml/src/ggml-cuda/conv-transpose-1d.cu +++ b/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -34,6 +34,10 @@ static __global__ void conv_transpose_1d_kernel( } } dst[global_index] = accumulator; + GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3); + GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3); + GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1); + GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2); } static void conv_transpose_1d_f32_f32_cuda( @@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor const int p0 = 0;//opts[3]; const int d0 = 1;//opts[4]; - const int64_t kernel_size = ggml_nelements(src0); - const int64_t input_size = ggml_nelements(src1); const int64_t output_size = ggml_nelements(dst); conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 795b720d..2997e2b4 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -577,7 +577,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res return; } - const src_t * x = (src_t *) vx; + const src_t * x = (const src_t *) vx; y[i] = x[i]; } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1c2a2a13..3fe22092 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( float vals[sizeof(int)] = {0.0f}; #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { vals[l] = scale * x[4*threadIdx.x + l]; } float amax = fabsf(vals[0]); float sum = vals[0]; #pragma unroll - for (int l = 1; l < sizeof(int); ++l) { + for (int l = 1; l < int(sizeof(int)); ++l) { amax = fmaxf(amax, fabsf(vals[l])); sum += vals[l]; } @@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( if (d != 0.0f) { #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { q8[l] = roundf(vals[l] / d); } } @@ -638,7 +638,7 @@ static __global__ void flash_attn_combine_results( float VKQ_denominator = 0.0f; for (int l = 0; l < parallel_blocks; ++l) { const float diff = meta[l].x - kqmax; - const float KQ_max_scale = expf(diff); + float KQ_max_scale = expf(diff); const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; @@ -649,6 +649,7 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; } +[[noreturn]] static void on_no_fattn_vec_case(const int D) { if (D == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 024032f6..04804a15 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // CP_ASYNC_AVAILABLE #else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); + GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); + GGML_UNUSED(kb0); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } #else + GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); + GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); + GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } @@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64); -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64); +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) // Kernels with ncols == 128 are only 4% faster due to register pressure. -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128); -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory. +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) +// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 77455d8e..e0039e17 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -282,7 +282,19 @@ static __global__ void flash_attn_tile_ext_f16( } } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 85fea440..81290c90 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -281,6 +281,18 @@ static __global__ void flash_attn_tile_ext_f32( } } #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 32c52ebe..e17d2d0e 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -292,7 +292,19 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); + GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); + GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); + GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 336c136d..70487485 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -277,6 +277,16 @@ static __global__ void flash_attn_vec_ext_f32( dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); + GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); + GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); + GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 5c214ea3..bc21b27a 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -430,7 +430,17 @@ static __global__ void flash_attn_ext_f16( dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; } #else - NO_DEVICE_CODE; + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); + GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); + NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 9206bfeb..2af63355 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(ret) : "r"(x)); #else + GGML_UNUSED(x); NO_DEVICE_CODE; #endif // defined(NEW_MMA_AVAILABLE) return ret; @@ -178,6 +179,7 @@ namespace ggml_cuda_mma { : "l"(xs)); #else load_generic(xs0, stride); + GGML_UNUSED(t); #endif // NEW_MMA_AVAILABLE } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f136c419..53235801 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (k01 < WARP_SIZE/2) { - constexpr int ns = 2; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } else { - constexpr int ns = 1; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, - &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); - } + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } } @@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -1253,7 +1268,7 @@ template static __device__ __forceinlin const float d = bxi->d; #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; } #else @@ -1376,7 +1391,7 @@ template static __device__ __forceinlin const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); } } @@ -1517,7 +1532,7 @@ template static __device__ __forceinlin const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int l = 0; l < sizeof(int); ++l) { + for (int l = 0; l < int(sizeof(int)); ++l) { x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); } } @@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile( } else { write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j); } + + GGML_UNUSED(ne00); GGML_UNUSED(ne10); } @@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: - if (it != blockIdx.x || jt != blockIdx.y) { + if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) { continue; } @@ -2825,7 +2842,6 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); - const int nsm = ggml_cuda_info().devices[id].nsm; const int cc = ggml_cuda_info().devices[id].cc; const int smpbo = ggml_cuda_info().devices[id].smpbo; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index f89ed03b..b39961cd 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -29,7 +29,7 @@ static __global__ void mul_mat_vec( __syncthreads(); } - float sumf; + float sumf = 0.0f; if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 45ea30f6..eef8585a 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // partial sum for each thread - float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; + float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = (const block_q8_1 *) vy; @@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) { dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; } } + + GGML_UNUSED(nrows_x); } static std::pair calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) { diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index aba539e8..77432b04 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) { + if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { int offset_src = nidx + blockIdx.y * ne00 + diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index cf513c3a..524e9795 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst, int i02 = i12 / sf2; int i03 = i13 / sf3; - dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); + dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); } static void upscale_f32_cuda(const float * x, float * dst,