llamafile : ppc64le GEMV forwarding for FP32. (llama/12594)

This patch enables usage of MMA when one of the
dimensions of the matrix(ie either M or N) is 1. This
is useful in case of token generation where N < 2.

The concept of 'GEMV Forwarding' is used where when one
of the matrix has a single row/column, the elements are
broadcasted, instead of using packing routine to prepack
the matrix elements.

This change results in 5% - 15% improvement in total
speed(ie all tokens/total time), across various batch
sizes. This is in comparision with the corresponding
dot product implementation.

The patch is tested with FP32 models of Meta-Lllama-3-8B,
Mistral-7B, Llama-2-7B-chat-hf on a IBM POWER10 machine.

Signed-off-by: Amrita H S <amritahs@linux.vnet.ibm.com>
This commit is contained in:
amritahs-ibm 2025-03-28 13:13:22 +05:30 committed by Georgi Gerganov
parent 5bad2e5099
commit 0001ec075f

@ -2680,13 +2680,25 @@ class tinyBLAS_PPC {
__builtin_mma_xxsetaccz(&acc_0);
vec_t vec_A[4] {0}, vec_B[4] = {0};
for (int l=0; l<k; l+=4) {
if (RN >= 4 && RM == 1) {
/* 'GEMV Forwarding' concept is used in first two conditional loops.
* when one of the matrix has a single row/column, the elements are
* broadcasted, instead of using packing routine to prepack the
* matrix elements.
*/
if (RM == 1) {
TA* a = const_cast<TA*>(A+(ii)*lda+l);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
vec_A[0] = (vec_t)vec_xl(0,a);
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
} else if (RN == 1) {
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
vec_B[0] = (vec_t)vec_xl(0,b);
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
} else {
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
@ -2790,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
assert(params->ith < params->nth);
// only enable sgemm for prompt processing
#if !defined(__MMA__)
if (n < 2)
return false;
#endif
if (Ctype != GGML_TYPE_F32)
return false;