From fd5a3e1bc61710ed81ae0a4df86151c413c567fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 24 Apr 2025 15:57:10 +0200 Subject: [PATCH] CUDA: use switch statements in constexpr functions (llama/13095) --- ggml/src/ggml-cuda/mmq.cuh | 80 ++++++++++++++++++++------------------ ggml/src/ggml-cuda/mmvq.cu | 80 ++++++++++++++++++++------------------ 2 files changed, 84 insertions(+), 76 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 53235801..3cb20155 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : - tile_x_sizes{0, 0, 0}; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) @@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : - 0; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index cac04916..d846e35a 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -7,47 +7,51 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : - type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : - type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : - type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : - type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : - type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : - type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : - type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : - nullptr; + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } } static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - 1; + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } } enum mmvq_parameter_table_id {