mirror of https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-20 17:11:17 +00:00
Another shot at AVX-512 support
commit 7fc52fa7ef
parent 01e037c6c6
Makefile | 2 changed lines
@@ -43,7 +43,7 @@ endif
 # feel free to update the Makefile for your architecture and send a pull request or issue
 ifeq ($(UNAME_M),x86_64)
 	# AVX 512
-	CFLAGS += -mavx512f -mavx512dq -mfma -mf16c
+	CFLAGS += -mavx512f -mfma -mf16c
 
 	# AVX 256
 	#CFLAGS += -mavx -mavx2 -mfma -mf16c
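Side note (not part of the commit): the old loop assembled 512-bit vectors with _mm512_insertf32x8, which is an AVX512DQ intrinsic, so the build needed -mavx512dq. The new loop widens fp16 lanes directly with _mm512_cvtph_ps, which requires only AVX512F, so the DQ flag can be dropped; this also lets the code path run on AVX512F-only parts such as Knights Landing. A minimal sketch of the new load pattern (the helper name is mine, not from ggml.c):

    #include <immintrin.h>

    #if defined(__AVX512F__)
    // Load 16 consecutive IEEE fp16 values and widen them to 16 floats.
    // VCVTPH2PS on 512-bit registers is AVX512F, so building with
    // -mavx512f -mfma -mf16c is now sufficient for this code path.
    static inline __m512 load_fp16x16_as_f32(const void * p) {
        return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *) p));
    }
    #endif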
ggml.c | 73 changed lines
@@ -568,37 +568,16 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     __m512 x0, x1, x2, x3;
     __m512 y0, y1, y2, y3;
 
-    __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi;
-
     for (int i = 0; i < n64; i += 64) {
-        // TODO: is this the best way to do this?
-        t0lo = _mm256_loadu_ps((const float*)(x + i + 0 ));
-        t0hi = _mm256_loadu_ps((const float*)(x + i + 8 ));
-        t1lo = _mm256_loadu_ps((const float*)(x + i + 16));
-        t1hi = _mm256_loadu_ps((const float*)(x + i + 24));
-        t2lo = _mm256_loadu_ps((const float*)(x + i + 32));
-        t2hi = _mm256_loadu_ps((const float*)(x + i + 40));
-        t3lo = _mm256_loadu_ps((const float*)(x + i + 48));
-        t3hi = _mm256_loadu_ps((const float*)(x + i + 56));
-
-        x0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
-        x1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
-        x2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
-        x3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
-
-        t0lo = _mm256_loadu_ps((const float*)(y + i + 0 ));
-        t0hi = _mm256_loadu_ps((const float*)(y + i + 8 ));
-        t1lo = _mm256_loadu_ps((const float*)(y + i + 16));
-        t1hi = _mm256_loadu_ps((const float*)(y + i + 24));
-        t2lo = _mm256_loadu_ps((const float*)(y + i + 32));
-        t2hi = _mm256_loadu_ps((const float*)(y + i + 40));
-        t3lo = _mm256_loadu_ps((const float*)(y + i + 48));
-        t3hi = _mm256_loadu_ps((const float*)(y + i + 56));
-
-        y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
-        y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
-        y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
-        y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
-
+        x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 )));
+        x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16)));
+        x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32)));
+        x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48)));
+
+        y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 )));
+        y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16)));
+        y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32)));
+        y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48)));
+
         sum0 = _mm512_fmadd_ps(x0, y0, sum0);
         sum1 = _mm512_fmadd_ps(x1, y1, sum1);
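For reference, a scalar sketch (mine, not from the diff) of what the vectorized loop above computes: a dot product over fp16 inputs accumulated in fp32, with the SIMD version processing 64 elements per iteration across the four sum0..sum3 accumulators before the tail is handled:

    #include <stdint.h>

    typedef uint16_t ggml_fp16_t;                 // fp16 storage type, as in ggml.h
    extern float ggml_fp16_to_fp32(ggml_fp16_t);  // conversion helper from ggml

    // Scalar equivalent of the AVX-512 dot-product loop (illustrative only).
    static void vec_dot_f16_scalar(const int n, float * s,
                                   const ggml_fp16_t * x, const ggml_fp16_t * y) {
        float sum = 0.0f;
        for (int i = 0; i < n; i++) {
            sum += ggml_fp16_to_fp32(x[i]) * ggml_fp16_to_fp32(y[i]);
        }
        *s = sum;
    }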
@@ -953,36 +932,26 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
     __m512 x0, x1, x2, x3;
     __m512 y0, y1, y2, y3;
 
-    __m256 t0lo, t0hi, t1lo, t1hi, t2lo, t2hi, t3lo, t3hi;
-
     for (int i = 0; i < n64; i += 64) {
-        t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 0 )));
-        t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 8 )));
-        t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 16)));
-        t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 24)));
-        t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 32)));
-        t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 40)));
-        t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 48)));
-        t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(y + i + 56)));
-
-        y0 = _mm512_insertf32x8(_mm512_castps256_ps512(t0lo), t0hi, 1);
-        y1 = _mm512_insertf32x8(_mm512_castps256_ps512(t1lo), t1hi, 1);
-        y2 = _mm512_insertf32x8(_mm512_castps256_ps512(t2lo), t2hi, 1);
-        y3 = _mm512_insertf32x8(_mm512_castps256_ps512(t3lo), t3hi, 1);
-
-        t0lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 0 )));
-        t0hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 8 )));
-        t1lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16)));
-        t1hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24)));
-        t2lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 32)));
-        t2hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 40)));
-        t3lo = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 48)));
-        t3hi = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 56)));
-
+        x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 )));
+        x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16)));
+        x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32)));
+        x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48)));
+
+        y0 = _mm512_loadu_ps(y + i + 0 );
+        y1 = _mm512_loadu_ps(y + i + 16);
+        y2 = _mm512_loadu_ps(y + i + 32);
+        y3 = _mm512_loadu_ps(y + i + 48);
+
         y0 = _mm512_fmadd_ps(x0, v16, y0);
         y1 = _mm512_fmadd_ps(x1, v16, y1);
         y2 = _mm512_fmadd_ps(x2, v16, y2);
         y3 = _mm512_fmadd_ps(x3, v16, y3);
 
         _mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0));
         _mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0));
         _mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0));
         _mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0));
     }
 
     // leftovers
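The page is truncated at the // leftovers comment. In ggml the vectorized loops are normally followed by a scalar tail for the elements past the last multiple of 64; a sketch under that assumption (not the committed lines), using ggml's fp16 conversion helpers:

    // y[i] += x[i]*v for the remaining n - n64 elements (illustrative only).
    for (int i = n64; i < n; ++i) {
        y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
    }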