diff --git a/ggml-metal.m b/ggml-metal.m
index 86426e93..7f0f1f1f 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -1941,7 +1941,12 @@ static enum ggml_status ggml_metal_graph_compute(
                                     {
                                         nth0 = 4;
                                         nth1 = 16;
+                                    #if QK_K == 64
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+                                    #else
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                                    #endif
+
                                     } break;
                                 default:
                                     {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7f840ab0..79cce21f 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -864,15 +864,16 @@ void mul_vec_q_n_f32_impl(
         device const void  * src0,
         device const float * src1,
         device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
+                   constant int64_t &   ne00,
+                   constant int64_t &   ne01,
+                   constant int64_t &   ne02,
+                   constant int64_t &   ne10,
+                   constant int64_t &   ne12,
+                   constant int64_t &   ne0,
+                   constant int64_t &   ne1,
+                   constant uint &      r2,
+                   constant uint &      r3,
+                   threadgroup int8_t * shared_values,
                    uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
 
@@ -949,7 +950,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -975,7 +976,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1001,7 +1002,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1027,7 +1028,7 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 
@@ -1046,6 +1047,7 @@ void kernel_mul_mv_q8_0_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1126,7 +1128,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 #define N_F32_F32 4
@@ -2716,6 +2718,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2878,7 +2881,7 @@ kernel void kernel_mul_mv_q2_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 #if QK_K == 256
@@ -2895,6 +2898,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3053,6 +3057,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3142,7 +3147,7 @@ kernel void kernel_mul_mv_q3_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 #if QK_K == 256
@@ -3159,6 +3164,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3272,6 +3278,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3380,7 +3387,7 @@ kernel void kernel_mul_mv_q4_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q5_K_f32_impl(
@@ -3396,6 +3403,7 @@ void kernel_mul_mv_q5_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3586,7 +3594,7 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q6_K_f32_impl(
@@ -3602,6 +3610,7 @@ void kernel_mul_mv_q6_K_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3720,7 +3729,7 @@ kernel void kernel_mul_mv_q6_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
@@ -4403,6 +4412,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4492,6 +4502,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4600,11 +4611,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
-        threadgroup float  * shared_values [[threadgroup(0)]],
+        threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK4_NL;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
@@ -4694,11 +4706,11 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
-        threadgroup float  * shared_values [[threadgroup(0)]],
+        threadgroup int8_t  * shared_values_i8 [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
+    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
@@ -4801,7 +4813,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -4829,7 +4841,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -4853,7 +4865,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
-        threadgroup float * shared_values [[threadgroup(0)]],
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -4882,7 +4894,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         constant   int64_t & ne1,
         constant   uint    & r2,
         constant   uint    & r3,
-        threadgroup float * shared_values [[threadgroup(0)]],
+        threadgroup int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -6029,8 +6041,139 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel
 // matrix-vector multiplication
 //
 
-[[host_name("kernel_mul_mv_id_f32_f32")]]
-kernel void kernel_mul_mv_id_f32_f32(
+typedef void (kernel_mul_mv_impl_t)( // function type for mat-vec impls that take byte strides (nb*) and no simdgroup index
+        device const  char * src0, // raw byte pointer; element type is left to the impl
+        device const  char * src1, // raw byte pointer; element type is left to the impl
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]]); // NOTE(review): presumably matches kernel_mul_mv_f32_f32_impl / kernel_mul_mv_f16_f32_impl - confirm
+
+typedef void (kernel_mul_mv2_impl_t)( // function type for mat-vec impls that take no nb* strides but do take threadgroup scratch and a simdgroup index
+        device const  void * src0, // untyped; the block layout is left to the impl
+        device const float * src1, // typed as float here, unlike the char pointer in kernel_mul_mv_impl_t
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   uint    & r2,
+        constant   uint    & r3,
+        threadgroup int8_t * shared_values [[threadgroup(0)]], // byte-typed scratch pointer; the iq4_nl/iq4_xs impls in this patch cast it to float
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]); // NOTE(review): presumably matches the quantized *_f32_impl functions - confirm
+
+template<kernel_mul_mv_impl_t impl_fn> // adapter: wraps a kernel_mul_mv_impl_t function in the wider common parameter list
+void mmv_fn(
+        device const    char * src0,
+        device const    char * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13, // accepted but not forwarded
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1, // accepted but not forwarded
+        constant        uint & r2,
+        constant        uint & r3,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]], // accepted but not forwarded
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]], // accepted but not forwarded
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) { // accepted but not forwarded
+    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); // pure forwarding call; drops ne13, nb1, shared_values, tiitg, sgitg
+}
+
+template<kernel_mul_mv2_impl_t impl_fn> // adapter: wraps a kernel_mul_mv2_impl_t function in the wider common parameter list
+void mmv_fn(
+        device const    char * src0,
+        device const    char * src1, // reinterpreted as float data in the call below
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00, // accepted but not forwarded
+        constant    uint64_t & nb01, // accepted but not forwarded
+        constant    uint64_t & nb02, // accepted but not forwarded
+        constant     int64_t & ne10,
+        constant     int64_t & ne11, // accepted but not forwarded
+        constant     int64_t & ne12,
+        constant     int64_t & ne13, // accepted but not forwarded
+        constant    uint64_t & nb10, // accepted but not forwarded
+        constant    uint64_t & nb11, // accepted but not forwarded
+        constant    uint64_t & nb12, // accepted but not forwarded
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1, // accepted but not forwarded
+        constant        uint & r2,
+        constant        uint & r3,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]], // accepted but not forwarded
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); // src0 converts implicitly to void*; src1 needs the explicit cast to float*
+}
+
+typedef void (mul_mv_impl_fn_t)( // common function type: same parameter list as both mmv_fn overloads; template parameter type of kernel_mul_mv_id
+        device const    char * src0,
+        device const    char * src1,
+        device         float * dst,
+        constant     int64_t & ne00,
+        constant     int64_t & ne01,
+        constant     int64_t & ne02,
+        constant    uint64_t & nb00,
+        constant    uint64_t & nb01,
+        constant    uint64_t & nb02,
+        constant     int64_t & ne10,
+        constant     int64_t & ne11,
+        constant     int64_t & ne12,
+        constant     int64_t & ne13,
+        constant    uint64_t & nb10,
+        constant    uint64_t & nb11,
+        constant    uint64_t & nb12,
+        constant     int64_t & ne0,
+        constant     int64_t & ne1,
+        constant    uint64_t & nb1,
+        constant        uint & r2,
+        constant        uint & r3,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]], // byte-typed so one signature serves impls with different scratch element types
+        uint3                  tgpig[[threadgroup_position_in_grid]],
+        uint                   tiitg[[thread_index_in_threadgroup]],
+        uint                   tiisg[[thread_index_in_simdgroup]],
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]); // superset of the arguments taken by kernel_mul_mv_impl_t and kernel_mul_mv2_impl_t
+
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
         device const    char * src0s,
         device const    char * src1,
         device         float * dst,
@@ -6055,6 +6198,7 @@ kernel void kernel_mul_mv_id_f32_f32(
         constant        uint & r2,
         constant        uint & r3,
         constant         int & idx,
+        threadgroup int8_t   * shared_values [[threadgroup(0)]],
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
@@ -6066,7 +6210,7 @@ kernel void kernel_mul_mv_id_f32_f32(
     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
     device const char * src0 = src0s + id*nb02;
 
-    kernel_mul_mv_f32_f32_impl(
+    impl_fn(
         src0,
         src1 + bid*nb11,
         dst  + bid*ne0,
@@ -6079,630 +6223,23 @@ kernel void kernel_mul_mv_id_f32_f32(
         ne10,
         ne11,
         ne12,
+        ne13,
         nb10,
         nb11,
         nb12,
         ne0,
         ne1,
+        nb1,
         r2,
         r3,
+        shared_values,
         tgpig,
-        tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_f16_f32")]]
-kernel void kernel_mul_mv_id_f16_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_f16_f32_impl(
-        src0,
-        src1 + bid*nb11,
-        dst  + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        nb00,
-        nb01,
-        nb02,
-        ne10,
-        ne11,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg);
-}
-
-[[host_name("kernel_mul_mv_id_q8_0_f32")]]
-kernel void kernel_mul_mv_id_q8_0_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q8_0_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
+        tiitg,
         tiisg,
         sgitg);
 }
 
-[[host_name("kernel_mul_mv_id_q4_0_f32")]]
-kernel void kernel_mul_mv_id_q4_0_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_1_f32")]]
-kernel void kernel_mul_mv_id_q4_1_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_0_f32")]]
-kernel void kernel_mul_mv_id_q5_0_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_1_f32")]]
-kernel void kernel_mul_mv_id_q5_1_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q2_K_f32")]]
-kernel void kernel_mul_mv_id_q2_K_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q2_K_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q3_K_f32")]]
-kernel void kernel_mul_mv_id_q3_K_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q3_K_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q4_K_f32")]]
-kernel void kernel_mul_mv_id_q4_K_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q4_K_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q5_K_f32")]]
-kernel void kernel_mul_mv_id_q5_K_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q5_K_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_q6_K_f32")]]
-kernel void kernel_mul_mv_id_q6_K_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_q6_K_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xxs_f32(
+typedef void (kernel_mul_mv_id_t)(
         device const    char * src0s,
         device const    char * src1,
         device         float * dst,
@@ -6731,485 +6268,29 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
+        uint                   sgitg[[simdgroup_index_in_threadgroup]]);
 
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq2_xxs_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
-kernel void kernel_mul_mv_id_iq2_xs_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq2_xs_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
-kernel void kernel_mul_mv_id_iq3_xxs_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq3_xxs_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
-kernel void kernel_mul_mv_id_iq3_s_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq3_s_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
-kernel void kernel_mul_mv_id_iq2_s_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq2_s_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
-kernel void kernel_mul_mv_id_iq1_s_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq1_s_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
-kernel void kernel_mul_mv_id_iq1_m_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq1_m_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
-kernel void kernel_mul_mv_id_iq4_nl_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup float    * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-    kernel_mul_mv_iq4_nl_f32_impl(
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
-
-[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
-kernel void kernel_mul_mv_id_iq4_xs_f32(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup float    * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int64_t bid = tgpig.z/(ne12*ne13);
-
-    tgpig.z = tgpig.z%(ne12*ne13);
-
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-    device const char * src0 = src0s + id*nb02;
-
-#if QK_K == 64
-    kernel_mul_mv_iq4_nl_f32_impl(
-#else
-    kernel_mul_mv_iq4_xs_f32_impl(
+template [[host_name("kernel_mul_mv_id_f32_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
+#if QK_K != 64
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
 #endif
-        src0,
-        (device const float *) (src1 + bid*nb11),
-        dst + bid*ne0,
-        ne00,
-        ne01,
-        ne02,
-        ne10,
-        ne12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_values,
-        tgpig,
-        tiisg,
-        sgitg);
-}
+
diff --git a/ggml.c b/ggml.c
index 2c4b8ec4..ba06665a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11074,7 +11074,6 @@ static void ggml_compute_forward_mul_mat_id(
         }
 
         // initialize matrix_row_counts
-        GGML_ASSERT(wdata == wdata_src1_end);
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
         // group rows by src0 matrix