ggml-cpu: Support s390x SIMD Instruction Set (llama/12019)

* ggml: add s390x ARCH_FLAGS for compilation

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add SIMD for s390x using vector intrinsics

SIMD is activated for:
* ggml_vec_dot_f32
* ggml_vec_dot_f16
* ggml_vec_mad_f32
* ggml_vec_mad_f16
* ggml_vec_mad_f32_unroll
* ggml_vec_scale_f32
* ggml_vec_scale_f16

SIMD is NOT activated for:
* ggml_vec_dot_f16_unroll (pending bugfix)

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix missing escape character in GGML_F32x4_REDUCE

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add temporary patch for GGML_F32_ARR and GGML_F16_ARR

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix s390x GGML_F32x4_REDUCE

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: full SIMD activation for F32,F16 s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add option to disable s390x VXE/VXE2

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: change vecintrin.h include to ggml-cpu-impl

* add __VXE__ and __VXE2__ macros

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* cmake: add s390x target detection for VX/VXE/VXE2

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: move s390x vector intrinsics to ggml-cpu-impl.h

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x Q8_0 SIMD

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: correct documentation for Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x reduce code complexity Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x bugfix typo Q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activated for Q4_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x inline vec_reve

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q4_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add VXE backend feature

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: remove test.py

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for quantize_row_q8_0

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for quantize_row_q8_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for iq4_xs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: bugfix iq4_xs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for iq4_nl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add float, double, and long vector data type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: clean up iq4_xs SIMD

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix improper use of restrict keyword

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: update warning message for ggml_vec_tbl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: untested implementation of ggml_vec_dot_iq2_xxs_q8_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: update ggml_vec_dot_q4_1_q8_1 to use typedefs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: switch to restrict for iq4_nl

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: slight dot product speed improvement for q4_1_q8_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for q6_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add missing `_t` to ggml_int8x16x4_t

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix missing `_t` for ggml_vec_xl_s8x4

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix more missing `_t`

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add unroll and prefetch to Q8_0

increase of 3.86% for prompt processing and 32.22% for token generation

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: patch Q8_0 to use proper vector sizes

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: optimise Q8_0 dot prod compute kernel further

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: add unroll and prefetch to Q4_1

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: refactor Q6_K variable naming for readability

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q6_K typos

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q5_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix wrong char*x16_t naming

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: Q5_K y0 wrong signness

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q5_K invalid uchar type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q5_K invalid uchar type

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: s390x SIMD activation for Q4_K

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: fix Q4_K invalid vector intrinsics

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: simplify ggml_padd_s16 compute kernel

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: correct ggml-cpu vxe wording

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: change ggml_aligned_malloc alignment to 256

256 is the cache line size for s390x platforms

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: resolve pr merge via cherry-pick 225bbbf

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml : fix LoongArch compile error with 128-bit SIMD (llama/11701)

* ggml: resolve pr merge via cherry-pick 4571953

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* ggml: cmake remove fork when determining s390x machine type

thank you @ericcurtin

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
Co-authored-by: Jinyang He <hejinyang@loongson.cn>
Co-authored-by: junchao-zhao <68935141+junchao-loongson@users.noreply.github.com>
This commit is contained in:
Aaron Teo 2025-02-23 05:39:24 +08:00 committed by Georgi Gerganov
parent 38ac47cd4d
commit 82e04e7670
8 changed files with 826 additions and 1 deletions

View File

@ -122,6 +122,7 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON) option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -99,6 +99,7 @@ extern "C" {
// other // other
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void);

View File

@ -306,6 +306,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (GGML_RVV) if (GGML_RVV)
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
endif() endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
message(STATUS "s390x detected")
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
elseif (${S390X_M} MATCHES "3931")
message(STATUS "z16 target")
list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
else()
message(STATUS "Unknown target")
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
list(APPEND ARCH_FLAGS -march=native -mtune=native)
endif()
if (GGML_VXE)
list(APPEND ARCH_FLAGS -mvx -mzvector)
endif()
else() else()
message(STATUS "Unknown architecture") message(STATUS "Unknown architecture")
endif() endif()

View File

@ -59,6 +59,15 @@ struct ggml_compute_params {
#endif #endif
#endif #endif
#if defined(__s390x__) && defined(__VEC__)
#ifndef __VXE__
#define __VXE__
#endif
#ifndef __VXE2__
#define __VXE2__
#endif
#endif
#if defined(__ARM_FEATURE_SVE) #if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h> #include <arm_sve.h>
#include <sys/prctl.h> #include <sys/prctl.h>
@ -359,6 +368,148 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#endif #endif
#endif #endif
#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>
#define vec_neg(a) (-(a)) // Vector Negate
#define vec_add(a, b) ((a) + (b)) // Vector Add
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
#define vec_div(a, b) ((a) / (b)) // Vector Divide
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
#ifndef vec_and
#define vec_and(a, b) ((a) & (b)) // Vector AND
#endif
#ifndef vec_or
#define vec_or(a, b) ((a) | (b)) // Vector OR
#endif
#ifndef vec_xor
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
#endif
typedef signed char char8x16_t __attribute__((vector_size(16)));
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
typedef int8_t int8x16_t __attribute__((vector_size(16)));
typedef int16_t int16x8_t __attribute__((vector_size(16)));
typedef int32_t int32x4_t __attribute__((vector_size(16)));
typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
typedef float float32x4_t __attribute__((vector_size(16)));
typedef double double64x2_t __attribute((vector_size(16)));
typedef signed long long long64x2_t __attribute((vector_size(16)));
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
typedef struct ggml_uint8x16x2_t {
uint8x16_t val[2];
} ggml_uint8x16x2_t;
inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
ggml_uint8x16x2_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
return res;
}
typedef struct ggml_uint8x16x4_t {
uint8x16_t val[4];
} ggml_uint8x16x4_t;
inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
ggml_uint8x16x4_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
res.val[2] = vec_xl(32, ptr);
res.val[3] = vec_xl(48, ptr);
return res;
}
typedef struct ggml_int8x16x4_t {
int8x16_t val[4];
} ggml_int8x16x4_t;
inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
ggml_int8x16x4_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
res.val[2] = vec_xl(32, ptr);
res.val[3] = vec_xl(48, ptr);
return res;
}
typedef struct ggml_int16x8x2_t {
int16x8_t val[2];
} ggml_int16x8x2_t;
inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
ggml_int16x8x2_t res;
res.val[0] = vec_xl( 0, ptr);
res.val[1] = vec_xl(16, ptr);
return res;
}
/*
! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
! or iq4_nl for example implementation.
*/
inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
int8x16_t res;
res[ 0] = a[b[ 0]];
res[ 1] = a[b[ 1]];
res[ 2] = a[b[ 2]];
res[ 3] = a[b[ 3]];
res[ 4] = a[b[ 4]];
res[ 5] = a[b[ 5]];
res[ 6] = a[b[ 6]];
res[ 7] = a[b[ 7]];
res[ 8] = a[b[ 8]];
res[ 9] = a[b[ 9]];
res[10] = a[b[10]];
res[11] = a[b[11]];
res[12] = a[b[12]];
res[13] = a[b[13]];
res[14] = a[b[14]];
res[15] = a[b[15]];
return res;
}
inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
const uchar8x16_t v_maske = { 0, 1, 4, 5, 8, 9, 12, 13,
16, 17, 20, 21, 24, 25, 28, 29 };
const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
const int16x8_t v_abe = vec_perm(a, b, v_maske);
return v_abo + v_abe;
}
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
return acc + (vec_unpackh(p) + vec_unpackl(p));
}
#endif
#if defined(__loongarch_asx) #if defined(__loongarch_asx)
/* float type data load instructions */ /* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(const float val) { static __m128 __lsx_vreplfr2vr_s(const float val) {

View File

@ -1011,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
} }
#elif defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) {
__vector float srcv [8];
__vector float asrcv[8];
__vector float amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
vec_extract(amaxv[0], 1)),
MAX(vec_extract(amaxv[0], 2),
vec_extract(amaxv[0], 3)));
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f / d : 0.0f;
y[i].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1);
y[i].qs[4*j + 2] = vec_extract(vi, 2);
y[i].qs[4*j + 3] = vec_extract(vi, 3);
}
}
#else #else
GGML_UNUSED(nb); GGML_UNUSED(nb);
// scalar // scalar
@ -1337,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
__lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
__lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
} }
#elif defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) {
__vector float srcv [8];
__vector float asrcv[8];
__vector float amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
vec_extract(amaxv[0], 1)),
MAX(vec_extract(amaxv[0], 2),
vec_extract(amaxv[0], 3)));
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f / d : 0.0f;
y[i].d = GGML_FP32_TO_FP16(d);
__vector int32_t acc = vec_splats(0);
for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1);
y[i].qs[4*j + 2] = vec_extract(vi, 2);
y[i].qs[4*j + 3] = vec_extract(vi, 3);
acc = vec_add(acc, vi);
}
y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
}
#else #else
GGML_UNUSED(nb); GGML_UNUSED(nb);
// scalar // scalar
@ -2488,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
} }
sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#elif defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f);
const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
for (; ib < nb; ++ib) {
const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
const __vector int8_t v_xls = vec_sub(v_xl, v_s);
const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
__vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3];
#endif #endif
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
int sumi0 = 0; int sumi0 = 0;
@ -2781,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
} }
sumf = hsum_float_8(acc) + summs; sumf = hsum_float_8(acc) + summs;
#elif defined(__VXE__) || defined(__VXE2__)
float summs = 0;
float32x4_t acc = vec_splats(0.0f);
const uint8x16_t v_m = vec_splat_u8(0x0F);
#pragma GCC unroll 4
for (; ib < nb; ++ib) {
__builtin_prefetch(x[ib].qs, 0, 1);
__builtin_prefetch(y[ib].qs, 0, 1);
summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
const uint8x16_t v_x = vec_xl(0, x[ib].qs);
const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
const float32x4_t v_xy = vec_float(v_xy_);
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
#endif #endif
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
int sumi0 = 0; int sumi0 = 0;
@ -3915,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
} }
sumf = hsum_float_8(acc); sumf = hsum_float_8(acc);
#elif defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f);
#pragma GCC unroll 8
for (; ib < nb; ++ib) {
__builtin_prefetch(x[ib].qs, 0, 1);
__builtin_prefetch(y[ib].qs, 0, 1);
const int8x16_t v_xl = vec_xl(0 , x[ib].qs);
const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
const float32x4_t v_xy = vec_float(v_xy_);
const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3];
#endif #endif
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
int sumi = 0; int sumi = 0;
@ -6797,6 +6948,77 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#elif defined(__VXE__) || defined(__VXE2__)
const uint8x16_t v_lm = vec_splat_u8(0x0F);
const int32x4_t v_z = vec_splat_s32(0);
uint8x16_t v_x[2];
int8x16_t v_xl[2];
int8x16_t v_y[2];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
memcpy(utmp, x[i].scales, 12);
uint32x4_t v_mins8 = { 0 };
v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
const int32x4_t v_mins = v_minso + v_minse;
sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict x0 = x[i].qs;
const int8_t * restrict y0 = y[i].qs;
int32_t sumi1 = 0;
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
v_x[0] = vec_xl(0 , x0);
v_x[1] = vec_xl(16, x0);
x0 += 32;
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
y0 += 32;
v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
y0 += 32;
v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
}
sumf += d * (sumi1 + sumi2);
}
*s = sumf;
#else #else
const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * scales = (const uint8_t*)&utmp[0];
@ -7526,7 +7748,94 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#elif defined(__VXE__) || defined(__VXE2__)
const uint8x16_t v_lm = vec_splat_u8(0x0F);
const uint8x16_t v_1m = vec_splat_u8(0x01);
const uint8x16_t v_2m = vec_splat_u8(0x02);
const int32x4_t v_z = vec_splat_s32(0);
const uchar8x16_t v_minsm = {
0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
};
int8x16_t q5b[4];
uint8x16_t q5h[4];
uint8x16_t v_xl[2];
uint8x16_t v_xh[2];
int8x16_t v_y[4];
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict x0l = x[i].qs;
const uint8_t * restrict x0h = x[i].qh;
const int8_t * restrict y0 = y[i].qs;
v_xh[0] = vec_xl(0 , x0h);
v_xh[1] = vec_xl(16, x0h);
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
v_xl[0] = vec_xl(0 , x0l);
v_xl[1] = vec_xl(16, x0l);
x0l += 32;
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
v_y[2] = vec_xl(32, y0);
v_y[3] = vec_xl(48, y0);
y0 += 64;
q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
v_xh[0] = vec_sr(v_xh[0], 2);
v_xh[1] = vec_sr(v_xh[1], 2);
q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
}
sumf += d * sumi - dmin * mins;
}
*s = sumf;
#else #else
const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * scales = (const uint8_t*)&utmp[0];
@ -8243,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
} }
*s = hsum_float_8(acc); *s = hsum_float_8(acc);
#elif defined(__VXE__) || defined(__VXE2__)
float sum = 0;
// Lower 4-bit and upper 2-bit masks
const uint8x16_t v_lm = vec_splat_u8(0x0F);
const uint8x16_t v_um = vec_splat_u8(0x03);
const int32x4_t v_z = vec_splat_s32(0);
int8x16_t q6b[4];
uint8x16_t q6h[4];
uint8x16_t v_xl[4];
uint8x16_t v_xh[2];
int8x16_t v_y[4];
for (int i = 0; i < nb; ++i) {
const float d_all = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict x0l = x[i].ql;
const uint8_t * restrict x0h = x[i].qh;
const int8_t * restrict y0 = y[i].qs;
const int8_t * restrict scale = x[i].scales;
const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
const int8x16_t v_scale = vec_xl(0, scale);
const int16x8_t v_scalel = vec_unpackh(v_scale);
const int16x8_t v_scaleh = vec_unpackl(v_scale);
const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
int32_t isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
// Load model upper 2 bits
v_xh[0] = vec_xl(0 , x0h);
v_xh[1] = vec_xl(16, x0h);
x0h += 32;
// Load model lower 4 bits
v_xl[0] = vec_xl(0 , x0l);
v_xl[1] = vec_xl(16, x0l);
v_xl[2] = vec_xl(32, x0l);
v_xl[3] = vec_xl(48, x0l);
x0l += 64;
// Load activation quants
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
v_y[2] = vec_xl(32, y0);
v_y[3] = vec_xl(48, y0);
y0 += 64;
q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
uint8x16_t shifted = vec_sr(v_xh[0], 2);
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
shifted = vec_sr(v_xh[1], 2);
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
scale += 4;
// Load activation quants
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
v_y[2] = vec_xl(32, y0);
v_y[3] = vec_xl(48, y0);
y0 += 64;
shifted = vec_sr(v_xh[0], 4);
q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
shifted = vec_sr(v_xh[1], 4);
q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
shifted = vec_sr(v_xh[0], 6);
q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
shifted = vec_sr(v_xh[1], 6);
q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
scale += 4;
}
sum += d_all * y[i].d * (isum - 32 * mins);
}
*s = sum;
#else #else
int8_t aux8[QK_K]; int8_t aux8[QK_K];
@ -8604,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
} }
*s = 0.125f * hsum_float_8(accumf); *s = 0.125f * hsum_float_8(accumf);
//#elif defined(__VXE__) || defined(__VXE2__)
// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
//
// uint32_t aux32[4];
// const uint8_t * aux8 = (const uint8_t *)aux32;
//
// float sumf = 0;
//
// for (int i = 0; i < nb; ++i) {
// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
// const uint16_t * restrict q2 = x[i].qs;
// const int8_t * restrict q8 = y[i].qs;
//
// float sumf1 = 0, sumf2 = 0;
//
// for (int ib32 = 0; ib32 < QK_K/32; ib += 2) {
// int8x16_t q8b0 = vec_xl( 0, q8);
// int8x16_t qb81 = vec_xl(16, q8);
// int8x16_t q8b2 = vec_xl(32, q8);
// int8x16_t q8b3 = vec_xl(48, q8);
// q8 += 64;
//
// memcpy(aux32, q2, 4 * sizeof(uint32_t));
// q2 += 8;
//
// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
//
// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) };
// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) };
// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
//
// q2u0 = vec_mul(q2u0, q2s0);
// q2u1 = vec_mul(q2u1, q2s1);
// q2u2 = vec_mul(q2u2, q2s2);
// q2u3 = vec_mul(q2u3, q2s3);
//
// const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
// const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
//
// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
// }
//
// sumf += d * (sumf1 + sumf2);
// }
//
// *s = 0.25f * sumf;
#else #else
uint32_t aux32[2]; uint32_t aux32[2];
@ -11365,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
#elif defined(__VXE__) || defined(__VXE2__)
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
const uint8x16_t v_m = vec_splat_u8(0x0F);
for (; ib < nb; ++ib) {
const block_iq4_nl * restrict x0 = &x[ib];
const block_q8_0 * restrict y0 = &y[ib];
const uint8x16_t v_x = vec_xl(0, x0->qs);
int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
const int8x16_t v_yl = vec_xl(0 , y0->qs);
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
}
#endif #endif
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
@ -11643,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
} }
*s = hsum_float_8(accum); *s = hsum_float_8(accum);
#elif defined(__VXE__) || defined(__VXE2__)
const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
const uint8x16_t v_m = vec_splat_u8(0x0F);
float sumf = 0;
for (int ibl = 0; ibl < nb; ++ibl) {
const uint8_t * restrict q4 = x[ibl].qs;
const int8_t * restrict q8 = y[ibl].qs;
uint16_t h = x[ibl].scales_h;
int sumi1 = 0, sumi2 = 0;
for (int ib = 0; ib < QK_K/64; ++ib) {
const uint8x16_t v_x0 = vec_xl(0 , q4);
const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
q4 += 32;
int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
const int8x16_t v_y0 = vec_xl( 0, q8);
const int8x16_t v_y1 = vec_xl(16, q8);
const int8x16_t v_y2 = vec_xl(32, q8);
const int8x16_t v_y3 = vec_xl(48, q8);
q8 += 64;
int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
h >>= 4;
sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
}
sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
}
*s = sumf;
#else #else
float sumf = 0; float sumf = 0;

View File

@ -237,6 +237,8 @@ typedef pthread_t ggml_thread_t;
#else #else
#if defined(__POWER9_VECTOR__) #if defined(__POWER9_VECTOR__)
#define CACHE_LINE_SIZE 128 #define CACHE_LINE_SIZE 128
#elif defined(__VXE__) || defined(__VXE2__)
#define CACHE_LINE_SIZE 256
#else #else
#define CACHE_LINE_SIZE 64 #define CACHE_LINE_SIZE 64
#endif #endif
@ -1211,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE
#elif defined(__VXE__) || defined(__VXE2__)
#define GGML_SIMD
// F32 s390x
#define GGML_F32_STEP 32
#define GGML_F32_EPR 4
#define GGML_F32x4 __vector float
#define GGML_F32x4_ZERO vec_splats(0.0f)
#define GGML_F32x4_SET1 vec_splats
#define GGML_F32x4_LOAD(p) vec_xl(0, p)
#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p)
#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
#define GGML_F32x4_ADD vec_add
#define GGML_F32x4_MUL vec_mul
#define GGML_F32x4_REDUCE(res, x) \
{ \
int offset = GGML_F32_ARR >> 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
offset >>= 1; \
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
res = vec_extract(x[0], 0) + \
vec_extract(x[0], 1) + \
vec_extract(x[0], 2) + \
vec_extract(x[0], 3); \
}
#define GGML_F32_VEC GGML_F32x4
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F32_VEC_SET1 GGML_F32x4_SET1
#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD
#define GGML_F32_VEC_STORE GGML_F32x4_STORE
#define GGML_F32_VEC_FMA GGML_F32x4_FMA
#define GGML_F32_VEC_ADD GGML_F32x4_ADD
#define GGML_F32_VEC_MUL GGML_F32x4_MUL
#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
// F16 s390x
#define GGML_F16_STEP GGML_F32_STEP
#define GGML_F16_EPR GGML_F32_EPR
static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
float tmp[4];
for (int i = 0; i < 4; i++) {
tmp[i] = GGML_FP16_TO_FP32(x[i]);
}
return vec_xl(0, tmp);
}
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
float arr[4];
vec_xst(y, 0, arr);
for (int i = 0; i < 4; i++) {
x[i] = GGML_FP32_TO_FP16(arr[i]);
}
}
#define GGML_F16_VEC GGML_F32x4
#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO
#define GGML_F16_VEC_SET1 GGML_F32x4_SET1
#define GGML_F16_VEC_LOAD(p, i) __lzs_f16cx4_load(p)
#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
#define GGML_F16_VEC_FMA GGML_F32x4_FMA
#define GGML_F16_VEC_ADD GGML_F32x4_ADD
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
#endif #endif
// GGML_F32_ARR / GGML_F16_ARR // GGML_F32_ARR / GGML_F16_ARR
@ -14419,6 +14502,14 @@ int ggml_cpu_has_vsx(void) {
#endif #endif
} }
int ggml_cpu_has_vxe(void) {
#if defined(__VXE__) || defined(__VXE2__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_neon(void) { int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON) #if defined(__ARM_ARCH) && defined(__ARM_NEON)
return ggml_arm_arch_features.has_neon; return ggml_arm_arch_features.has_neon;

View File

@ -557,6 +557,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_vsx()) { if (ggml_cpu_has_vsx()) {
features.push_back({ "VSX", "1" }); features.push_back({ "VSX", "1" });
} }
if (ggml_cpu_has_vxe()) {
features.push_back({ "VXE", "1" });
}
if (ggml_cpu_has_wasm_simd()) { if (ggml_cpu_has_wasm_simd()) {
features.push_back({ "WASM_SIMD", "1" }); features.push_back({ "WASM_SIMD", "1" });
} }

View File

@ -240,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
void * ggml_aligned_malloc(size_t size) { void * ggml_aligned_malloc(size_t size) {
#if defined(__s390x__)
const int alignment = 256;
#else
const int alignment = 64; const int alignment = 64;
#endif
#if defined(_MSC_VER) || defined(__MINGW32__) #if defined(_MSC_VER) || defined(__MINGW32__)
return _aligned_malloc(size, alignment); return _aligned_malloc(size, alignment);