mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-04-09 04:15:15 +00:00)

fixed compilation warnings in ggml-sycl (llama/12424)

parent 52c4c03b0a
commit 6c15539c54
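Editorial note: two mechanical fixes repeat throughout the hunks below. Kernel lambdas switch from the Intel-specific attribute [[intel::reqd_sub_group_size(N)]] to the SYCL 2020 spelling [[sycl::reqd_sub_group_size(N)]], which newer DPC++ compilers flag as deprecated, and a number of file-local helpers gain internal linkage via static. A minimal, self-contained sketch of the attribute usage follows; the queue setup and the fill_ones kernel are illustrative and not taken from ggml:

#include <sycl/sycl.hpp>

// File-local device helper; `static` gives it internal linkage.
static void fill_ones(float * dst, int k, const sycl::nd_item<3> & it) {
    const int i = it.get_local_range(2) * it.get_group(2) + it.get_local_id(2);
    if (i < k) {
        dst[i] = 1.0f;
    }
}

int main() {
    constexpr int k = 256;
    sycl::queue q;
    float * dst = sycl::malloc_device<float>(k, q);

    q.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, k), sycl::range<3>(1, 1, 32)),
        // SYCL 2020 spells the attribute [[sycl::reqd_sub_group_size(N)]];
        // the older [[intel::reqd_sub_group_size(N)]] form still compiles on
        // DPC++ but emits the deprecation warnings this commit removes.
        [=](sycl::nd_item<3> it) [[sycl::reqd_sub_group_size(32)]] {
            fill_ones(dst, k, it);
        }).wait();

    sycl::free(dst, q);
    return 0;
}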
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});

@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
nrows, item_ct1);
});

@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, item_ct1);
});

@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
});
}

@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
});
}

@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
});
}

@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
});
}

@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
});
}

@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
default:
printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
GGML_ABORT("fatal error");
- break;
}
GGML_UNUSED(src1);
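Editorial note: the hunks for the element-wise kernels that follow change only linkage. Each kernel and launch helper is defined in one translation unit and never declared in a header, so compilers that warn about missing prototypes flag every definition; adding static limits the symbol to its translation unit and silences the warning without changing behaviour. A small stand-alone illustration, with a hypothetical scale_f32 helper that is not part of ggml:

// With external linkage and no prior declaration this definition would
// trigger missing-prototype style warnings; `static` avoids that.
static float scale_f32(float x, float s) {
    return x * s;
}

static float sum_scaled(const float * x, int n, float s) {
    float acc = 0.0f;
    for (int i = 0; i < n; ++i) {
        acc += scale_f32(x[i], s);
    }
    return acc;
}

int main() {
    const float v[3] = { 1.0f, 2.0f, 3.0f };
    return sum_scaled(v, 3, 0.5f) == 3.0f ? 0 : 1;
}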
@@ -1,7 +1,7 @@
#include "common.hpp"
#include "element_wise.hpp"
- void acc_f32(const float * x, const float * y, float * dst, const int ne,
+ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +

@@ -20,7 +20,7 @@ void acc_f32(const float * x, const float * y, float * dst, const int ne,
}
}
- void gelu_f32(const float * x, float * dst, const int k,
+ static void gelu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -37,7 +37,7 @@ void gelu_f32(const float * x, float * dst, const int k,
sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
}
- void silu_f32(const float * x, float * dst, const int k,
+ static void silu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -48,7 +48,7 @@ void silu_f32(const float * x, float * dst, const int k,
dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
}
- void gelu_quick_f32(const float *x, float *dst, int k,
+ static void gelu_quick_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) {
const float GELU_QUICK_COEF = -1.702f;
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +

@@ -59,7 +59,7 @@ void gelu_quick_f32(const float *x, float *dst, int k,
dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
}
- void tanh_f32(const float *x, float *dst, int k,
+ static void tanh_f32(const float *x, float *dst, int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -69,7 +69,7 @@ void tanh_f32(const float *x, float *dst, int k,
dst[i] = sycl::tanh((float)(x[i]));
}
- void relu_f32(const float * x, float * dst, const int k,
+ static void relu_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -80,7 +80,7 @@ void relu_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmax((float)(x[i]), (float)0);
}
- void sigmoid_f32(const float * x, float * dst, const int k,
+ static void sigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -91,7 +91,7 @@ void sigmoid_f32(const float * x, float * dst, const int k,
dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
}
- void sqrt_f32(const float * x, float * dst, const int k,
+ static void sqrt_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -102,7 +102,7 @@ void sqrt_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sqrt(x[i]);
}
- void sin_f32(const float * x, float * dst, const int k,
+ static void sin_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -113,7 +113,7 @@ void sin_f32(const float * x, float * dst, const int k,
dst[i] = sycl::sin(x[i]);
}
- void cos_f32(const float * x, float * dst, const int k,
+ static void cos_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -124,7 +124,7 @@ void cos_f32(const float * x, float * dst, const int k,
dst[i] = sycl::cos(x[i]);
}
- void hardsigmoid_f32(const float * x, float * dst, const int k,
+ static void hardsigmoid_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -135,7 +135,7 @@ void hardsigmoid_f32(const float * x, float * dst, const int k,
dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
}
- void hardswish_f32(const float * x, float * dst, const int k,
+ static void hardswish_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -146,7 +146,7 @@ void hardswish_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
}
- void exp_f32(const float * x, float * dst, const int k,
+ static void exp_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -157,7 +157,7 @@ void exp_f32(const float * x, float * dst, const int k,
dst[i] = sycl::exp(x[i]);
}
- void log_f32(const float * x, float * dst, const int k,
+ static void log_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -173,7 +173,7 @@ void log_f32(const float * x, float * dst, const int k,
}
}
- void neg_f32(const float * x, float * dst, const int k,
+ static void neg_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -184,7 +184,7 @@ void neg_f32(const float * x, float * dst, const int k,
dst[i] = -x[i];
}
- void step_f32(const float * x, float * dst, const int k,
+ static void step_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -195,7 +195,7 @@ void step_f32(const float * x, float * dst, const int k,
dst[i] = x[i] > 0.0f;
}
- void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
+ static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -206,7 +206,7 @@ void leaky_relu_f32(const float *x, float *dst, const int k, const float negativ
sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
}
- void sqr_f32(const float * x, float * dst, const int k,
+ static void sqr_f32(const float * x, float * dst, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);

@@ -217,7 +217,7 @@ void sqr_f32(const float * x, float * dst, const int k,
dst[i] = x[i] * x[i];
}
- void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
+ static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {

@@ -240,7 +240,7 @@ void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
- void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
const sycl::nd_item<3> &item_ct1) {
int nidx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);

@@ -262,7 +262,7 @@ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const i
- void acc_f32_sycl(const float *x, const float *y, float *dst,
+ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {

@@ -277,7 +277,7 @@ void acc_f32_sycl(const float *x, const float *y, float *dst,
});
}
- void gelu_f32_sycl(const float *x, float *dst, const int k,
+ static void gelu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(

@@ -289,7 +289,7 @@ void gelu_f32_sycl(const float *x, float *dst, const int k,
});
}
- void silu_f32_sycl(const float *x, float *dst, const int k,
+ static void silu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for(

@@ -301,7 +301,7 @@ void silu_f32_sycl(const float *x, float *dst, const int k,
});
}
- void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+ static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(

@@ -313,7 +313,7 @@ void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
});
}
- void tanh_f32_sycl(const float *x, float *dst, const int k,
+ static void tanh_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for(

@@ -325,7 +325,7 @@ void tanh_f32_sycl(const float *x, float *dst, const int k,
});
}
- void relu_f32_sycl(const float *x, float *dst, const int k,
+ static void relu_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for(

@@ -337,7 +337,7 @@ void relu_f32_sycl(const float *x, float *dst, const int k,
});
}
- void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+ static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for(

@@ -349,7 +349,7 @@ void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
});
}
- void hardswish_f32_sycl(const float *x, float *dst, const int k,
+ static void hardswish_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for(

@@ -361,7 +361,7 @@ void hardswish_f32_sycl(const float *x, float *dst, const int k,
});
}
- void exp_f32_sycl(const float *x, float *dst, const int k,
+ static void exp_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(

@@ -373,7 +373,7 @@ void exp_f32_sycl(const float *x, float *dst, const int k,
});
}
- void log_f32_sycl(const float *x, float *dst, const int k,
+ static void log_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(

@@ -385,7 +385,7 @@ void log_f32_sycl(const float *x, float *dst, const int k,
});
}
- void neg_f32_sycl(const float *x, float *dst, const int k,
+ static void neg_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(

@@ -397,7 +397,7 @@ void neg_f32_sycl(const float *x, float *dst, const int k,
});
}
- void step_f32_sycl(const float *x, float *dst, const int k,
+ static void step_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(

@@ -409,7 +409,7 @@ void step_f32_sycl(const float *x, float *dst, const int k,
});
}
- void sigmoid_f32_sycl(const float *x, float *dst, const int k,
+ static void sigmoid_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for(

@@ -421,7 +421,7 @@ void sigmoid_f32_sycl(const float *x, float *dst, const int k,
});
}
- void sqrt_f32_sycl(const float *x, float *dst, const int k,
+ static void sqrt_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for(

@@ -433,7 +433,7 @@ void sqrt_f32_sycl(const float *x, float *dst, const int k,
});
}
- void sin_f32_sycl(const float *x, float *dst, const int k,
+ static void sin_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(

@@ -445,7 +445,7 @@ void sin_f32_sycl(const float *x, float *dst, const int k,
});
}
- void cos_f32_sycl(const float *x, float *dst, const int k,
+ static void cos_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(

@@ -457,7 +457,7 @@ void cos_f32_sycl(const float *x, float *dst, const int k,
});
}
- void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+ static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
const float negative_slope,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;

@@ -470,7 +470,7 @@ void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
});
}
- void sqr_f32_sycl(const float *x, float *dst, const int k,
+ static void sqr_f32_sycl(const float *x, float *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for(

@@ -482,7 +482,7 @@ void sqr_f32_sycl(const float *x, float *dst, const int k,
});
}
- void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+ static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
const int ne12, const int ne13, const float sf0, const float sf1,
const float sf2, const float sf3, queue_ptr stream) {

@@ -496,7 +496,7 @@ void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01
});
}
- void pad_f32_sycl(const float *x, float *dst, const int ne00,
+ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
const int ne01, const int ne02, const int ne0,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
@@ -207,7 +207,7 @@ static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_te
const size_t nrows = ne01;
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
k_get_rows_reorder<qk, qr, dq_reorder>(
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);

@@ -302,7 +302,6 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *s
// TODO: k-quants
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ABORT("fatal error");
- break;
}
}
@@ -95,7 +95,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
return info;
}
- void print_device_detail(int id, sycl::device &device, std::string device_type) {
+ static void print_device_detail(int id, sycl::device &device, std::string device_type) {
dpct::device_info prop;
SYCL_CHECK(CHECK_TRY_ERROR(

@@ -118,7 +118,7 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}
- void print_device_opt_feature(int device_count) {
+ static void print_device_opt_feature(int device_count) {
GGML_LOG_INFO("SYCL Optimization Feature:\n");
GGML_LOG_INFO(
"|ID| Device Type|Reorder|\n");

@@ -401,7 +401,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
- void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+ static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
const void *ptr_src, size_t size) {
char *host_buf = (char *)malloc(size);
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();

@@ -620,7 +620,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
return &ggml_backend_sycl_buffer_types[device];
}
- ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+ static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
int device = ctx->device;

@@ -1682,7 +1682,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
stream->parallel_for(
sycl::nd_range<3>(num_blocks * block_size, block_size),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
});
}

@@ -1703,7 +1703,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
nchannels_y, item_ct1);
});

@@ -1723,7 +1723,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
row_stride_x, channel_stride_x,
nchannels_y / nchannels_x, item_ct1);

@@ -1764,7 +1764,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
k_sum_rows_f32(x, dst, ncols, item_ct1);
});
}

@@ -2920,7 +2920,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
return false;
}
- bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:

@@ -3293,7 +3293,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
}
- void ggml_sycl_set_main_device(const int main_device) try {
+ static void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
return;
}

@@ -3314,7 +3314,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
- bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
+ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {

@@ -3638,7 +3638,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
- void reorder_qw(char *data_device, const int ncols, const int nrows,
+ static void reorder_qw(char *data_device, const int ncols, const int nrows,
size_t size, size_t offset, dpct::queue_ptr stream) {
auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
SYCL_CHECK(

@@ -3652,7 +3652,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
stream->parallel_for(
size / sizeof(block_q4_0),
- [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
const block_q4_0* x = (const block_q4_0*)tmp_buf;
const int ib = i;

@@ -3666,7 +3666,7 @@ void reorder_qw(char *data_device, const int ncols, const int nrows,
sycl::free(tmp_buf, *stream);
}
- void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+ static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
char*data_device = (char*)src0->data;
size_t ncols = src0->ne[0];
size_t nrows = src0->ne[1];

@@ -3675,7 +3675,7 @@ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
reorder_qw(data_device, ncols, nrows, size, 0, stream);
}
- void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
+ static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
ggml_tensor *src0 = dst->src[0];
ggml_tensor *src1 = dst->src[1];

@@ -3688,7 +3688,7 @@ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
}
}
- void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
+ static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
dpct::queue_ptr stream = ctx->stream();
if (ctx->optimized_graph) {
return;

@@ -3878,7 +3878,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
}
return false;
- } break;
+ }
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_NEG:

@@ -3896,7 +3896,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default:
return false;
}
- break;
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{

@@ -3927,7 +3926,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
return true;
- } break;
+ }
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
case GGML_OP_GET_ROWS:

@@ -3944,7 +3943,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
default:
return false;
}
- } break;
+ }
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;

@@ -3995,12 +3994,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return true;
}
return false;
- } break;
+ }
case GGML_OP_CONCAT:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
+ }
case GGML_OP_DUP:
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
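Editorial note: the one-line deletions in the remaining hunks (for example the @@ -3017,7 +3017,6 @@ hunk just below) appear to drop a break; that directly follows GGML_ABORT(...) or a block in which every path already returns; since control can never reach it, the statement is reported as unreachable code. A minimal illustration with a hypothetical [[noreturn]] helper rather than the real ggml macro:

#include <cstdio>
#include <cstdlib>

// Stand-in for an abort macro; [[noreturn]] tells the compiler that control
// never continues past a call to this function.
[[noreturn]] static void fatal(const char * msg) {
    std::fprintf(stderr, "%s\n", msg);
    std::abort();
}

static int element_size(int type) {
    switch (type) {
        case 0:
            return 4;
        default:
            fatal("unsupported type");
            // No trailing `break;` here: it could never execute, and the
            // compiler would flag it as unreachable code.
    }
}

int main() {
    return element_size(0) == 4 ? 0 : 1;
}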
@@ -3017,7 +3017,6 @@ void ggml_sycl_op_mul_mat_q(
break;
default:
GGML_ABORT("fatal error");
- break;
}
GGML_UNUSED(src1);
@@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0,
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -591,7 +591,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);

@@ -734,7 +734,7 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -755,7 +755,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -777,7 +777,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -799,7 +799,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -821,7 +821,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -843,7 +843,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -864,7 +864,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -886,7 +886,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -908,7 +908,7 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});

@@ -1003,7 +1003,6 @@ void ggml_sycl_op_mul_mat_vec_q(
break;
default:
GGML_ABORT("fatal error");
- break;
}
}
GGML_UNUSED(src1);
@@ -235,7 +235,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});

@@ -258,7 +258,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});

@@ -277,7 +277,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);

@@ -304,7 +304,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);

@@ -325,7 +325,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});

@@ -347,7 +347,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
- [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
@@ -132,7 +132,7 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
m1, n_head_log2, item_ct1,