mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-28 19:25:56 +00:00
CUDA: app option to compile without FlashAttention (llama/12025)
This commit is contained in:
parent
2d70cd36d7
commit
38ac47cd4d
@ -151,6 +151,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"ggml: max. batch size for using peer access")
|
||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
|
||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
|
@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
|
@ -204,9 +204,9 @@ typedef float2 dfloat2;
|
||||
#define CP_ASYNC_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static bool fp16_available(const int cc) {
|
||||
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
|
||||
|
@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef NEW_MMA_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int ncols1, int ncols2>
|
||||
|
@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
|
@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||
|
@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
|
@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||
|
@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
NO_DEVICE_CODE;
|
||||
@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
|
@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
|
||||
add_compile_definitions(GGML_HIP_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
target_link_libraries(ggml-hip PRIVATE hip::device)
|
||||
|
@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
|
||||
add_compile_definitions(GGML_CUDA_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
|
||||
add_compile_definitions(GGML_CUDA_F16)
|
||||
endif()
|
||||
|
Loading…
x
Reference in New Issue
Block a user