From 42938398f90bdd23d5797c6189663b46b80d3ab1 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Fri, 2 May 2025 22:23:12 +0530 Subject: [PATCH] ggml : Enable MMA for BF16 in llamafile_sgemm (llama/13148) This patch upstreams llamafile's cpu matrix multiplication kernels for ppc64le using MMA builtins for BF16 data type. This change results in 9x - 40x gains in total speed S t/s (i.e. all tokens/total time), across various batch sizes tested using llama-batched-bench benchmark. The patch is tested with Meta-Llama-3-8B and Mistral-7B models (BF16 models generated by using llama-quantize from corresponding FP32 models) on an IBM POWER10 machine. Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 501 ++++++++++++++++++++++++++ 1 file changed, 501 insertions(+) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index f6374f78..1d46158f 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +class tinyBLAS_BF16_PPC { + public: + tinyBLAS_BF16_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + if (numVec == 2) { + t[0] = vec_perm(c[0], c[1], swiz1); + t[1] = vec_perm(c[2], c[3], swiz1); + s[0] = vec_perm(t[0], t[1], swiz3); + s[1] = vec_perm(t[0], t[1], swiz4); + 
vec_xst(s[0], 0, (vec_t*)vecOffset); + vec_xst(s[1], 0, (vec_t*)(vecOffset + 16)); + } else if (numVec == 4) { + t[0] = vec_perm(c[0], c[1], swiz1); + t[1] = vec_perm(c[0], c[1], swiz2); + t[2] = vec_perm(c[2], c[3], swiz1); + t[3] = vec_perm(c[2], c[3], swiz2); + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + for (int i = 0; i < 4; ++i) + vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16)); + } else if (numVec == 8) { + for (int i = 0; i < 4; i += 2) { + t[i+0] = vec_perm(c[i+0], c[i+1], swiz1); + t[i+1] = vec_perm(c[i+0], c[i+1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i+0] = vec_perm(c[i+0], c[i+1], swiz1); + t[i+1] = vec_perm(c[i+0], c[i+1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) + vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16)); + } + } + + void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) { + int64_t i, j; + TA *aoffset = NULL; + unsigned char *vecOffset = NULL; + TA * aoffsets[8]; + vector unsigned char c_arr[8]; + aoffset = const_cast(a); + vecOffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + if (cols == 4) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; + for (int i = 0; i < 4; ++i) + c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]); + vector_permute_store(c_arr, 4, vecOffset); + for (int i = 0; i<4; i++) + aoffsets[i] = aoffsets[i]+lda; + vecOffset +=64; + } + i = (cols >> 3); + if (i > 0) { + aoffsets[0] = aoffset; + for (int it = 1; it < 8; ++it) { + aoffsets[it] = aoffsets[it-1] + lda; + } + aoffset += 8 * lda; + do { + for 
(int it = 0; it < 8; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 8, vecOffset); + for (int it = 0; it < 8; ++it) + aoffsets[it] = aoffsets[it] + 8*lda; + vecOffset += 128; + i--; + } while(i > 0); + } + j--; + } while(j > 0); + } + if (rows & 4) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; + if (cols == 4) { + for (int it = 0; it < 4; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 2, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + lda; + vecOffset += 32; + } + i = (cols >> 3); + if (i > 0) { + do { + for (int it = 0; it < 4; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 4, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; + vecOffset += 64; + i--; + } while(i > 0); + } + } + if (rows & 3) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + if (cols == 4) { + switch(rows) { + case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]); + case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]); + case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]); + break; + } + vector_permute_store(c_arr, 2, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + lda; + vecOffset += 32; + } + i = (cols >> 3); + if (i > 0) { + do { + switch(rows) { + case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]); + case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]); + case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]); + break; + } + vector_permute_store(c_arr, 4, vecOffset); + for (int it = 0; it <4; it++) + aoffsets[it] = aoffsets[it] + 8* lda; + vecOffset += 64; + i--; + } while(i > 0); + } + } + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + 
int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >=8 && n_rem >=4){ + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem >= 8)) { + nc = 8; + switch(m_rem) { + case 1: + mc = 1; + gemm_Mx8<1>(m0, m, n0, n); + break; + case 2: + mc = 2; + gemm_Mx8<2>(m0, m, n0, n); + break; + case 3: + mc = 3; + gemm_Mx8<3>(m0, m, n0, n); + break; + default: + return; + } + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm_small<4, 4>(m0, m, n0, n); + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 2: + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 3: + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small<3, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small<3, 1>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small<2,4>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small<2, 2>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small<2, 1>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small<1, 
3>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small<1, 1>(m0, m, n0, n); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[8] , vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int l = 0; l < k; l+=8) { + packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[4] , vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int l = 0; l < k; l+=8) { + packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); + packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii+4, jj); + } + + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[8], vec_C[4]; + acc_t acc_0, acc_1, acc_2, acc_3; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + for (int l = 0; l < k; l+=8) { + packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); + packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); + 
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + } + } + + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + SAVE_ACC(&acc_2, ii+4, jj); + SAVE_ACC(&acc_3, ii+4, jj+4); + } + + template + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + vec_t vec_A[2], vec_B[2]; + for (int l=0; l + void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int RN = 8; + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + vec_t vec_A[4], vec_B[8]; + for (int l=0; l + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * 
ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + kernel(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; + template class tinyBLAS_Q0_PPC { public: @@ -2202,6 +2689,7 @@ class tinyBLAS_PPC { boffset = vec; j = (rows >> 3); if (j > 0) { + do { aoffset1 = aoffset; aoffset2 = aoffset1 + lda; @@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 (float *)C, ldc}; return tb.matmul(m, n); } +#elif defined(__MMA__) + if ((k % 8)) + return false; + if(Btype == GGML_TYPE_BF16) { + tinyBLAS_BF16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; + } #endif return false; } + case GGML_TYPE_F16: { #if defined(__AVX512F__) if (Btype == GGML_TYPE_F16) {