diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 3942013f..456e1fd9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node( // ne00*(nsg) // each simdgroup has a full f16 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 80d0765b..b08666e2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S[Q] = { [0 ... Q-1] = 0.0f }; - half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + float S[Q] = { [0 ... Q-1] = 0.0f }; + float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext( const bool has_mask = mask != q; - half slope = 1.0f; + float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); @@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext( if (has_mask) { // used to detect blocks full of -INF - half smax = -INFINITY; + float smax = -INFINITY; // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); - const half m = pm[ic + tiisg]; + const float m = pm[ic + tiisg]; ss[j*TS + C + tiisg] = m; smax = max(smax, m); @@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext( // online softmax { for (ushort j = 0; j < Q; ++j) { - const half m = M[j]; + const float m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*args.scale; + float s = ss[j*TS + tiisg]*args.scale; if (args.logit_softcap != 0.0f) { s = args.logit_softcap*precise::tanh(s); @@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext( M[j] = simd_max(max(M[j], s)); - const half ms = exp(m - M[j]); - const half vs = exp(s - M[j]); + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); S[j] = S[j]*ms + simd_sum(vs); @@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - half S = { 0.0f }; - half M = { -__FLT16_MAX__/2 }; + float S = { 0.0f }; + float M = { -__FLT16_MAX__/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext( // the first simdgroup accumulates the results from the other simdgroups if (sgitg == 0) { for (short j = 0; j < Q; ++j) { - const half S0 = ss[j*TS + 0]; - const half S1 = ss[j*TS + sg*SH + 0]; + const float S0 = ss[j*TS + 0]; + const float S1 = ss[j*TS + sg*SH + 0]; - const half M0 = ss[j*TS + 1]; - const half M1 = ss[j*TS + sg*SH + 1]; + const float M0 = ss[j*TS + 1]; + const float M1 = ss[j*TS + sg*SH + 1]; M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); S = S0*ms0 + S1*ms1; @@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec( constexpr short DV4 = DV/4; constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads - constexpr short SH = 2*C; // shared memory per simdgroup + constexpr short SH = 4*C; // shared memory per simdgroup const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results // store the result for all queries in local memory (the O matrix from the paper) o4_t lo[DV4/NL]; @@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec( threadgroup_barrier(mem_flags::mem_threadgroup); { - half S = 0.0f; - half M = -__FLT16_MAX__/2; + float S = 0.0f; + float M = -__FLT16_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%NL; @@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec( // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31); - half slope = 1.0f; + float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const float base = h < args.n_head_log2 ? args.m0 : args.m1; const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); @@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec( // online softmax { - const half m = M; - const half s = ss[tiisg]; + const float m = M; + const float s = ss[tiisg]; M = simd_max(max(M, s)); - const half ms = exp(m - M); - const half vs = exp(s - M); + const float ms = exp(m - M); + const float vs = exp(s - M); S = S*ms + simd_sum(vs); @@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec( v4_t mv; deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); - lo[ii/NL] += mv*ms; + lo[ii/NL] += o4_t(float4(mv)*float4(ms)); } } } @@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec( // parallel reduce for (short r = nsg/2; r > 0; r >>= 1) { if (sgitg < r) { - const half S0 = ss[ 0]; - const half S1 = ss[r*SH + 0]; + const float S0 = ss[ 0]; + const float S1 = ss[r*(SH/2) + 0]; - const half M0 = ss[ 1]; - const half M1 = ss[r*SH + 1]; + const float M0 = ss[ 1]; + const float M1 = ss[r*(SH/2) + 1]; - const half M = max(M0, M1); + const float M = max(M0, M1); - const half ms0 = exp(M0 - M); - const half ms1 = exp(M1 - M); + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); - const half S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[0] = S; @@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec( // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // #define FA_TYPES \ - half4, \ - half4, \ - half4, \ - float, \ - half, half4, \ + half4, \ + half4, \ + half4, \ + float, \ + float, float4, \ half4 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;