mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-09 20:13:14 +00:00
ggml : Enable MMA for BF16 in llamafile_sgemm (llama/13148)
This patch upstreams llamafile's cpu matrix multiplication kernels for ppc64le using MMA builtins for BF16 data type. This change results in 9x - 40x gains in total speed S t/s (ie all tokens/total time), across various batch sizes tested using llama-batched-bench benchmark. The patch is tested with Meta-Lllama-3-8B, and Mistral-7B models (BF16 models generated by using llama-quantize from corresponding FP32 models) on an IBM POWER10 machine. Signed-off-by: Shalini Salomi Bodapati <Shalini.Salomi.Bodapati@ibm.com>
This commit is contained in:
parent
a8fe90ae15
commit
42938398f9
@ -1054,6 +1054,493 @@ class tinyBLAS_Q0_AVX {
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
template <typename TA, typename TB, typename TC>
|
||||||
|
class tinyBLAS_BF16_PPC {
|
||||||
|
public:
|
||||||
|
tinyBLAS_BF16_PPC(int64_t k,
|
||||||
|
const TA *A, int64_t lda,
|
||||||
|
const TB *B, int64_t ldb,
|
||||||
|
TC *C, int64_t ldc,
|
||||||
|
int ith, int nth)
|
||||||
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void matmul(int64_t m, int64_t n) {
|
||||||
|
mnpack(0, m, 0, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) {
|
||||||
|
vec_t t[8], s[8];
|
||||||
|
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||||
|
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||||
|
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||||
|
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||||
|
|
||||||
|
if (numVec == 2) {
|
||||||
|
t[0] = vec_perm(c[0], c[1], swiz1);
|
||||||
|
t[1] = vec_perm(c[2], c[3], swiz1);
|
||||||
|
s[0] = vec_perm(t[0], t[1], swiz3);
|
||||||
|
s[1] = vec_perm(t[0], t[1], swiz4);
|
||||||
|
vec_xst(s[0], 0, (vec_t*)vecOffset);
|
||||||
|
vec_xst(s[1], 0, (vec_t*)(vecOffset + 16));
|
||||||
|
} else if (numVec == 4) {
|
||||||
|
t[0] = vec_perm(c[0], c[1], swiz1);
|
||||||
|
t[1] = vec_perm(c[0], c[1], swiz2);
|
||||||
|
t[2] = vec_perm(c[2], c[3], swiz1);
|
||||||
|
t[3] = vec_perm(c[2], c[3], swiz2);
|
||||||
|
s[0] = vec_perm(t[0], t[2], swiz3);
|
||||||
|
s[1] = vec_perm(t[0], t[2], swiz4);
|
||||||
|
s[2] = vec_perm(t[1], t[3], swiz3);
|
||||||
|
s[3] = vec_perm(t[1], t[3], swiz4);
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
|
||||||
|
} else if (numVec == 8) {
|
||||||
|
for (int i = 0; i < 4; i += 2) {
|
||||||
|
t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
|
||||||
|
t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
|
||||||
|
}
|
||||||
|
for (int i = 4; i < 8; i += 2) {
|
||||||
|
t[i+0] = vec_perm(c[i+0], c[i+1], swiz1);
|
||||||
|
t[i+1] = vec_perm(c[i+0], c[i+1], swiz2);
|
||||||
|
}
|
||||||
|
s[0] = vec_perm(t[0], t[2], swiz3);
|
||||||
|
s[1] = vec_perm(t[0], t[2], swiz4);
|
||||||
|
s[2] = vec_perm(t[1], t[3], swiz3);
|
||||||
|
s[3] = vec_perm(t[1], t[3], swiz4);
|
||||||
|
s[4] = vec_perm(t[4], t[6], swiz3);
|
||||||
|
s[5] = vec_perm(t[4], t[6], swiz4);
|
||||||
|
s[6] = vec_perm(t[5], t[7], swiz3);
|
||||||
|
s[7] = vec_perm(t[5], t[7], swiz4);
|
||||||
|
for (int i = 0; i < 8; ++i)
|
||||||
|
vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) {
|
||||||
|
int64_t i, j;
|
||||||
|
TA *aoffset = NULL;
|
||||||
|
unsigned char *vecOffset = NULL;
|
||||||
|
TA * aoffsets[8];
|
||||||
|
vector unsigned char c_arr[8];
|
||||||
|
aoffset = const_cast<TA*>(a);
|
||||||
|
vecOffset = vec;
|
||||||
|
j = (rows >> 3);
|
||||||
|
if (j > 0) {
|
||||||
|
do {
|
||||||
|
if (cols == 4) {
|
||||||
|
aoffsets[0] = aoffset;
|
||||||
|
for (int it = 1; it < 4; ++it)
|
||||||
|
aoffsets[it] = aoffsets[it-1] + lda;
|
||||||
|
aoffset += 4 * lda;
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]);
|
||||||
|
vector_permute_store(c_arr, 4, vecOffset);
|
||||||
|
for (int i = 0; i<4; i++)
|
||||||
|
aoffsets[i] = aoffsets[i]+lda;
|
||||||
|
vecOffset +=64;
|
||||||
|
}
|
||||||
|
i = (cols >> 3);
|
||||||
|
if (i > 0) {
|
||||||
|
aoffsets[0] = aoffset;
|
||||||
|
for (int it = 1; it < 8; ++it) {
|
||||||
|
aoffsets[it] = aoffsets[it-1] + lda;
|
||||||
|
}
|
||||||
|
aoffset += 8 * lda;
|
||||||
|
do {
|
||||||
|
for (int it = 0; it < 8; ++it)
|
||||||
|
c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
|
||||||
|
vector_permute_store(c_arr, 8, vecOffset);
|
||||||
|
for (int it = 0; it < 8; ++it)
|
||||||
|
aoffsets[it] = aoffsets[it] + 8*lda;
|
||||||
|
vecOffset += 128;
|
||||||
|
i--;
|
||||||
|
} while(i > 0);
|
||||||
|
}
|
||||||
|
j--;
|
||||||
|
} while(j > 0);
|
||||||
|
}
|
||||||
|
if (rows & 4) {
|
||||||
|
aoffsets[0] = aoffset;
|
||||||
|
for (int it = 1; it < 4; ++it)
|
||||||
|
aoffsets[it] = aoffsets[it-1] + lda;
|
||||||
|
aoffset += 4 * lda;
|
||||||
|
if (cols == 4) {
|
||||||
|
for (int it = 0; it < 4; ++it)
|
||||||
|
c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
|
||||||
|
vector_permute_store(c_arr, 2, vecOffset);
|
||||||
|
for (int it = 0; it< 4; it++)
|
||||||
|
aoffsets[it] = aoffsets[it] + lda;
|
||||||
|
vecOffset += 32;
|
||||||
|
}
|
||||||
|
i = (cols >> 3);
|
||||||
|
if (i > 0) {
|
||||||
|
do {
|
||||||
|
for (int it = 0; it < 4; ++it)
|
||||||
|
c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]);
|
||||||
|
vector_permute_store(c_arr, 4, vecOffset);
|
||||||
|
for (int it = 0; it< 4; it++)
|
||||||
|
aoffsets[it] = aoffsets[it] + 8*lda;
|
||||||
|
vecOffset += 64;
|
||||||
|
i--;
|
||||||
|
} while(i > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (rows & 3) {
|
||||||
|
aoffsets[0] = aoffset;
|
||||||
|
for (int it = 1; it < 4; ++it)
|
||||||
|
aoffsets[it] = aoffsets[it-1] + lda;
|
||||||
|
if (cols == 4) {
|
||||||
|
switch(rows) {
|
||||||
|
case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
|
||||||
|
case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
|
||||||
|
case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
vector_permute_store(c_arr, 2, vecOffset);
|
||||||
|
for (int it = 0; it< 4; it++)
|
||||||
|
aoffsets[it] = aoffsets[it] + lda;
|
||||||
|
vecOffset += 32;
|
||||||
|
}
|
||||||
|
i = (cols >> 3);
|
||||||
|
if (i > 0) {
|
||||||
|
do {
|
||||||
|
switch(rows) {
|
||||||
|
case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]);
|
||||||
|
case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]);
|
||||||
|
case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
vector_permute_store(c_arr, 4, vecOffset);
|
||||||
|
for (int it = 0; it <4; it++)
|
||||||
|
aoffsets[it] = aoffsets[it] + 8* lda;
|
||||||
|
vecOffset += 64;
|
||||||
|
i--;
|
||||||
|
} while(i > 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t mc, nc, mp, np;
|
||||||
|
int m_rem = MIN(m - m0, 8);
|
||||||
|
int n_rem = MIN(n - n0, 8);
|
||||||
|
|
||||||
|
if (m_rem >= 8 && n_rem >= 8) {
|
||||||
|
mc = 8;
|
||||||
|
nc = 8;
|
||||||
|
gemm<8,8>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >= 4 && n_rem >= 8) {
|
||||||
|
mc = 4;
|
||||||
|
nc = 8;
|
||||||
|
gemm<4,8>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >=8 && n_rem >=4){
|
||||||
|
mc = 8;
|
||||||
|
nc = 4;
|
||||||
|
gemm<8,4>(m0, m, n0, n);
|
||||||
|
} else if ((m_rem < 4) && (n_rem >= 8)) {
|
||||||
|
nc = 8;
|
||||||
|
switch(m_rem) {
|
||||||
|
case 1:
|
||||||
|
mc = 1;
|
||||||
|
gemm_Mx8<1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
mc = 2;
|
||||||
|
gemm_Mx8<2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
mc = 3;
|
||||||
|
gemm_Mx8<3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if (m_rem >= 4 && n_rem >= 4) {
|
||||||
|
mc = 4;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small<4, 4>(m0, m, n0, n);
|
||||||
|
} else if ((m_rem > 4) && (n_rem < 4)) {
|
||||||
|
mc = 4;
|
||||||
|
switch(n_rem) {
|
||||||
|
case 1:
|
||||||
|
nc = 1;
|
||||||
|
gemm_small<4, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
nc = 2;
|
||||||
|
gemm_small<4, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
nc = 3;
|
||||||
|
gemm_small<4, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch((m_rem << 4) | n_rem) {
|
||||||
|
case 0x43:
|
||||||
|
mc = 4;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small<4, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x42:
|
||||||
|
mc = 4;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small<4, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x41:
|
||||||
|
mc = 4;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small<4, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x34:
|
||||||
|
mc = 3;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small<3, 4>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x33:
|
||||||
|
mc = 3;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small<3, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x32:
|
||||||
|
mc = 3;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small<3, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x31:
|
||||||
|
mc = 3;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small<3, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x24:
|
||||||
|
mc = 2;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small<2,4>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x23:
|
||||||
|
mc = 2;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small<2, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x22:
|
||||||
|
mc = 2;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small<2, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x21:
|
||||||
|
mc = 2;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small<2, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x14:
|
||||||
|
mc = 1;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small<1, 4>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x13:
|
||||||
|
mc = 1;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small<1, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x12:
|
||||||
|
mc = 1;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small<1, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x11:
|
||||||
|
mc = 1;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small<1, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp = m0 + (m - m0) / mc * mc;
|
||||||
|
np = n0 + (n - n0) / nc * nc;
|
||||||
|
mnpack(mp, m, n0, np);
|
||||||
|
mnpack(m0, m, np, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[4], vec_B[8] , vec_C[4];
|
||||||
|
acc_t acc_0, acc_1;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
for (int l = 0; l < k; l+=8) {
|
||||||
|
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
|
||||||
|
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
|
||||||
|
for (int x = 0; x < 4; x++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii, jj+4);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[8], vec_B[4] , vec_C[4];
|
||||||
|
acc_t acc_0, acc_1;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
for (int l = 0; l < k; l+=8) {
|
||||||
|
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
|
||||||
|
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||||
|
for (int x = 0; x < 4; x++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii+4, jj);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[8], vec_B[8], vec_C[4];
|
||||||
|
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_2);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_3);
|
||||||
|
for (int l = 0; l < k; l+=8) {
|
||||||
|
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
|
||||||
|
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
|
||||||
|
for (int x = 0; x < 4; x++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii, jj+4);
|
||||||
|
SAVE_ACC(&acc_2, ii+4, jj);
|
||||||
|
SAVE_ACC(&acc_3, ii+4, jj+4);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int RM, int RN>
|
||||||
|
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
vec_t vec_C[4];
|
||||||
|
acc_t acc_0;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
vec_t vec_A[2], vec_B[2];
|
||||||
|
for (int l=0; l<k; l+=4) {
|
||||||
|
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
|
||||||
|
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
|
||||||
|
for (int x = 0; x<2; x++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||||
|
for (int I = 0; I < RM; I++) {
|
||||||
|
for (int J = 0; J < RN; J++) {
|
||||||
|
*((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int RM>
|
||||||
|
void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int RN = 8;
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
vec_t vec_C[4];
|
||||||
|
acc_t acc_0, acc_1;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
vec_t vec_A[4], vec_B[8];
|
||||||
|
for (int l=0; l<k; l+=8) {
|
||||||
|
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
|
||||||
|
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
|
||||||
|
for (int x = 0; x<4; x++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||||
|
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||||
|
for (int I = 0; I < RM; I++) {
|
||||||
|
for (int J = 0; J < 4; J++) {
|
||||||
|
*((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__builtin_mma_disassemble_acc(vec_C, &acc_1);
|
||||||
|
for (int I = 0; I < RM; I++) {
|
||||||
|
for (int J = 0; J < 4; J++) {
|
||||||
|
*((TC*)(C+ii+((jj+4+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int RM, int RN>
|
||||||
|
inline void kernel(int64_t ii, int64_t jj) {
|
||||||
|
if constexpr(RM == 4 && RN == 8) {
|
||||||
|
KERNEL_4x8(ii,jj);
|
||||||
|
} else if constexpr(RM == 8 && RN == 8) {
|
||||||
|
KERNEL_8x8(ii,jj);
|
||||||
|
} else if constexpr(RM == 8 && RN == 4) {
|
||||||
|
KERNEL_8x4(ii,jj);
|
||||||
|
} else {
|
||||||
|
static_assert(false, "RN/RM values not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int RM, int RN>
|
||||||
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
kernel<RM, RN>(ii, jj);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TA *const A;
|
||||||
|
const TB *const B;
|
||||||
|
TC *C;
|
||||||
|
const int64_t k;
|
||||||
|
const int64_t lda;
|
||||||
|
const int64_t ldb;
|
||||||
|
const int64_t ldc;
|
||||||
|
const int ith;
|
||||||
|
const int nth;
|
||||||
|
};
|
||||||
|
|
||||||
template <typename TA, typename TB, typename TC>
|
template <typename TA, typename TB, typename TC>
|
||||||
class tinyBLAS_Q0_PPC {
|
class tinyBLAS_Q0_PPC {
|
||||||
public:
|
public:
|
||||||
@ -2202,6 +2689,7 @@ class tinyBLAS_PPC {
|
|||||||
boffset = vec;
|
boffset = vec;
|
||||||
j = (rows >> 3);
|
j = (rows >> 3);
|
||||||
if (j > 0) {
|
if (j > 0) {
|
||||||
|
|
||||||
do {
|
do {
|
||||||
aoffset1 = aoffset;
|
aoffset1 = aoffset;
|
||||||
aoffset2 = aoffset1 + lda;
|
aoffset2 = aoffset1 + lda;
|
||||||
@ -2875,9 +3363,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||||||
(float *)C, ldc};
|
(float *)C, ldc};
|
||||||
return tb.matmul(m, n);
|
return tb.matmul(m, n);
|
||||||
}
|
}
|
||||||
|
#elif defined(__MMA__)
|
||||||
|
if ((k % 8))
|
||||||
|
return false;
|
||||||
|
if(Btype == GGML_TYPE_BF16) {
|
||||||
|
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
||||||
|
(const ggml_bf16_t *)A, lda,
|
||||||
|
(const ggml_bf16_t *)B, ldb,
|
||||||
|
(float *)C, ldc,
|
||||||
|
params->ith, params->nth};
|
||||||
|
tb.matmul(m, n);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
#if defined(__AVX512F__)
|
#if defined(__AVX512F__)
|
||||||
if (Btype == GGML_TYPE_F16) {
|
if (Btype == GGML_TYPE_F16) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user