From fad2806352c58beca08667e30b2830c9ba192932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 2 Feb 2025 23:48:29 +0100 Subject: [PATCH] HIP: fix flash_attn_stream_k_fixup warning (llama/11604) --- ggml/src/ggml-cuda/fattn-common.cuh | 10 ++++++++++ ggml/src/ggml-cuda/softmax.cu | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfd7c0f4..d40ee2da 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,6 +516,12 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } +// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -614,6 +620,10 @@ static __global__ void flash_attn_stream_k_fixup( } } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index da377200..aac6e099 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32(half val) { #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" -#endif +#endif // __clang__ template static __global__ void soft_max_f32( const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, @@ -126,7 +126,7 @@ static __global__ void soft_max_f32( } #ifdef __clang__ #pragma clang diagnostic pop -#endif +#endif // __clang__ static __global__ void soft_max_back_f32( const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {