From 38ac47cd4d8876e28f509254f9c70e896c89466c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 22 Feb 2025 20:44:34 +0100
Subject: [PATCH] CUDA: add option to compile without FlashAttention
 (llama/12025)

---
 ggml/CMakeLists.txt                  | 1 +
 ggml/src/ggml-cuda/CMakeLists.txt    | 4 ++++
 ggml/src/ggml-cuda/common.cuh        | 4 ++--
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 8 ++++----
 ggml/src/ggml-cuda/fattn-tile-f16.cu | 9 ++-------
 ggml/src/ggml-cuda/fattn-tile-f32.cu | 8 ++++----
 ggml/src/ggml-cuda/fattn-vec-f16.cuh | 9 ++-------
 ggml/src/ggml-cuda/fattn-vec-f32.cuh | 8 ++++----
 ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 ++--
 ggml/src/ggml-cuda/ggml-cuda.cu      | 2 +-
 ggml/src/ggml-hip/CMakeLists.txt     | 4 ++++
 ggml/src/ggml-musa/CMakeLists.txt    | 4 ++++
 12 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index fc5eac15..12afe0f2 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -151,6 +151,7 @@ set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                      "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY        "ggml: do not use peer to peer copies"            OFF)
 option(GGML_CUDA_NO_VMM              "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA                  "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS       "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS              "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
 
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index e63ede2f..96bd5a0b 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 7e99838c..adf0d3ec 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -204,9 +204,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index b2e0db9a..718ee540 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef NEW_MMA_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // NEW_MMA_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
     flash_attn_ext_f16_process_tile
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
 
 template
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index b8b415ef..ef3569fa 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 4352a284..04b69c83 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index e758a0f6..b7686c1e 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index 134144a3..c1d2dd8d 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index de38470a..8828652f 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 }
 
 constexpr int get_max_power_of_2(int x) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f6854232..ebb2ccae 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index f4a46836..4a0384dd 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index 1bfc07c5..2c75abf6 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
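Not part of the patch: a minimal usage sketch, assuming a ggml-based project (for example llama.cpp) configured with the CUDA backend. The new option could be turned off at configure time roughly like this:

    # configure a CUDA build without the FlashAttention kernels
    cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA=OFF
    cmake --build build --config Release

With GGML_CUDA_FA=OFF, CMake defines GGML_CUDA_NO_FA, so FLASH_ATTN_AVAILABLE is never set in common.cuh, the fattn-* kernels compile down to NO_DEVICE_CODE stubs, and ggml_backend_cuda_device_supports_op reports GGML_OP_FLASH_ATTN_EXT as unsupported, leaving the op to be handled by another backend (typically the CPU).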