sycl: add usage of enqueue_functions extension (llama/14244)

* Add header and namespace to use enqueue_functions extension

* Convert submit and parallel_for to use new extension in convert.cpp

* Convert submit and parallel_for to use extension in ggml-sycl.cpp

* Convert submit and parallel_for to use extension in gla.cpp

* Convert submit and parallel_for in mmq.cpp

* Convert submit and parallel_for in mmvq.cpp

* Convert submit and parallel_for in remaining files

* Convert all simple parallel_for to nd_launch from enqueue_functions
extension

* Wrapping extension in general function

Create a general function that enable the enqueue_functions extension if
it is enable in the compiler, otherwise call the general SYCL function
to launch kernels.

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
This commit is contained in:
Nicolò Scipione
2025-06-20 15:07:21 +02:00
committed by Georgi Gerganov
parent af7168174c
commit a455dcb04c
19 changed files with 750 additions and 986 deletions

View File

@ -225,9 +225,9 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * stream,
sycl::range<3>(1, 1, block_size), sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)), sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>( k_bin_bcast_unravel<bin_op>(
@ -246,9 +246,8 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13, ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13, s1, s2, s3, s01, s02, s03, s11, s12, s13,

View File

@ -89,32 +89,23 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
sycl::range<3> gridDim(ne2, ne1, num_blocks); sycl::range<3> gridDim(ne2, ne1, num_blocks);
switch (dim) { switch (dim) {
case 0: case 0:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
});
break; break;
case 1: case 1:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
});
break; break;
// dim >=2 will be dispatched to the default path // dim >=2 will be dispatched to the default path
default: default:
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
});
break; break;
} }
} }
@ -129,26 +120,22 @@ static void concat_f32_sycl_non_cont(
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2, int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
uint64_t nb3, int32_t dim) { uint64_t nb3, int32_t dim) {
sycl::range<3> gridDim(ne3, ne2, ne1); sycl::range<3> gridDim(ne3, ne2, ne1);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
int64_t i3 = item_ct1.get_group(0); int64_t i3 = item_ct1.get_group(0);
int64_t i2 = item_ct1.get_group(1); int64_t i2 = item_ct1.get_group(1);
int64_t i1 = item_ct1.get_group(2); int64_t i1 = item_ct1.get_group(2);
int64_t o[4] = {0, 0, 0, 0}; int64_t o[4] = { 0, 0, 0, 0 };
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
const float *x; const float * x;
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
i0 += item_ct1.get_local_range(2)) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 + x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
(i0)*nb00);
} else { } else {
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10); (i0 - o[0]) * nb10);
} }
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);

View File

@ -59,15 +59,9 @@ static void conv_transpose_1d_f32_f32_sycl(
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE; const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE); const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks); const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>( conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
block_nums * block_dims, block_dims), item_ct1);
[=](sycl::nd_item<3> item_ct1) {
conv_transpose_1d_kernel(
s0, output_size,
src0_ne0, src0_ne1, src0_ne2,
src1_ne0, dst_ne0,
src0, src1, dst, item_ct1);
}); });
} }

View File

@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>( stream,
sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
});
} }
} }
@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
} }
#endif #endif
} }
@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_0(vx, y, nb32, item_ct1);
});
} }
} }
@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
int constexpr WARP_K = WARP_SIZE * QK4_0; int constexpr WARP_K = WARP_SIZE * QK4_0;
const int n_warp = (k + WARP_K - 1) / WARP_K; const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % 2 == 0); GGML_ASSERT(k % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl_parallel_for(stream,
sycl::range<3>(1, 1, WARP_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)), sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_block_q4_0_reorder(vx, y, k, item_ct1); dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
}); });
} }
template <typename dst_t> template <typename dst_t>
@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_1(vx, y, nb32, item_ct1);
});
} }
} }
@ -171,11 +149,10 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh); sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1); dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
}); });
@ -191,10 +168,10 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh); sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) { [=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
}); });
@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -242,24 +213,18 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 64), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
sycl::range<3>(1, 1, 64)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
} }
#else #else
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
} }
#endif #endif
@ -271,7 +236,7 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
} }
@ -284,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s(
vx, y, item_ct1, iq1s_grid_gpu
);
});
}); });
} }
} }
@ -305,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_m(
vx, y, item_ct1, iq1s_grid_gpu
);
});
}); });
} }
} }
@ -326,14 +281,11 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs( dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq2xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -347,14 +299,11 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs( dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq2xs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -368,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(vx, y, item_ct1);
});
}); });
} }
} }
@ -388,14 +334,11 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs( dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
vx, y, item_ct1, iq3xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
}); });
}); });
} }
@ -409,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl_parallel_for(
sycl::range<3>(1, 1, 32), cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_s(
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
});
}); });
} }
} }
@ -432,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * cgh,
sycl::range<3>(1, 1, 32), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_xs(vx, y, item_ct1);
});
}); });
} }
#endif #endif
@ -453,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * cgh,
sycl::range<3>(1, 1, 32), sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(vx, y, item_ct1);
});
}); });
} }
} }

View File

@ -413,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -431,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -449,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -465,7 +468,7 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0; const int num_blocks = ne / QK8_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -477,7 +480,7 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -490,7 +493,7 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0; const int num_blocks = ne / QK4_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -502,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -516,7 +520,7 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1; const int num_blocks = ne / QK4_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -528,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -542,7 +547,7 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0; const int num_blocks = ne / QK5_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -554,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -568,7 +574,7 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1; const int num_blocks = ne / QK5_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@ -580,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne; const int num_blocks = ne;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1); item_ct1);
@ -594,10 +601,10 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL; const int num_blocks = ne / QK4_NL;
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne12, nb10, nb11, nb12, nb13, item_ct1); ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -609,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
{ {
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -628,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(), // dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16}); // {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -647,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(), // dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16}); // {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
@ -662,10 +672,12 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -675,10 +687,12 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -689,10 +703,12 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -702,9 +718,12 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }
@ -715,9 +734,12 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) { const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
}); });
} }

View File

@ -208,11 +208,9 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
nrows, item_ct1);
}); });
} }
} }
@ -877,11 +875,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>( dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
vx, y, dst, ncols, nrows, item_ct1); nrows, item_ct1);
}); });
} }
} }
@ -900,11 +897,9 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>( dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -921,11 +916,9 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>( dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -942,11 +935,9 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>( dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -963,11 +954,9 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>( dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -984,11 +973,9 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>( dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
vx, y, dst, ncols, nrows, item_ct1);
}); });
} }
} }
@ -1002,8 +989,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1018,8 +1004,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1034,8 +1019,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
}); });
@ -1047,8 +1031,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
}); });
@ -1063,8 +1046,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
}); });

View File

@ -13,10 +13,10 @@
#ifndef GGML_SYCL_DPCT_HELPER_HPP #ifndef GGML_SYCL_DPCT_HELPER_HPP
#define GGML_SYCL_DPCT_HELPER_HPP #define GGML_SYCL_DPCT_HELPER_HPP
#include <map>
#include <sycl/sycl.hpp> #include <sycl/sycl.hpp>
#include <sycl/half_type.hpp> #include <sycl/half_type.hpp>
#include <syclcompat/math.hpp> #include <syclcompat/math.hpp>
#include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL #ifdef GGML_SYCL_USE_INTEL_ONEMKL
#include <oneapi/mkl.hpp> #include <oneapi/mkl.hpp>
@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
#endif #endif
} }
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
namespace syclex = sycl::ext::oneapi::experimental;
#endif
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(cgh, nd_range, func);
#else
cgh.parallel_for(nd_range, func);
#endif
}
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(*q, nd_range, func);
#else
q->parallel_for(nd_range, func);
#endif
}
template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::submit(*stream, func);
#else
stream->submit(func);
#endif
}
namespace dpct namespace dpct
{ {
typedef sycl::queue *queue_ptr; typedef sycl::queue *queue_ptr;

View File

@ -329,13 +329,11 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int ne12, const int nb1, const int nb2, const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) { const int offset, queue_ptr stream) {
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
item_ct1);
}); });
} }
@ -343,46 +341,39 @@ template<typename T>
static void gelu_sycl(const T *x, T *dst, const int k, static void gelu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
gelu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void silu_sycl(const T *x, T *dst, const int k, static void silu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
silu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
sgn(x, dst, k, item_ct1); [=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
});
} }
template<typename T> template<typename T>
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream,
abs_op(x, dst, k, item_ct1); sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
}); [=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
} }
@ -390,23 +381,20 @@ template<typename T>
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now // hard code for now
const int num_blocks = ceil_div(k, 256); const int num_blocks = ceil_div(k, 256);
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { stream,
elu_op(x, dst, k, item_ct1); sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
}); [=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
} }
template<typename T> template<typename T>
static void gelu_quick_sycl(const T *x, T *dst, const int k, static void gelu_quick_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
gelu_quick(x, dst, k, item_ct1);
});
} }
@ -414,169 +402,133 @@ template<typename T>
static void gelu_erf_sycl(const T *x, T *dst, const int k, static void gelu_erf_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
gelu_erf(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void tanh_sycl(const T *x, T *dst, const int k, static void tanh_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
tanh(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void relu_sycl(const T *x, T *dst, const int k, static void relu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
relu(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void hardsigmoid_sycl(const T *x, T *dst, const int k, static void hardsigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
hardsigmoid(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void hardswish_sycl(const T *x, T *dst, const int k, static void hardswish_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
hardswish(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void exp_sycl(const T *x, T *dst, const int k, static void exp_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
exp(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void log_sycl(const T *x, T *dst, const int k, static void log_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
log(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void neg_sycl(const T *x, T *dst, const int k, static void neg_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
neg(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void step_sycl(const T *x, T *dst, const int k, static void step_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
step(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sigmoid_sycl(const T *x, T *dst, const int k, static void sigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * stream,
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
sigmoid(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sqrt_sycl(const T *x, T *dst, const int k, static void sqrt_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
sqrt(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void sin_sycl(const T *x, T *dst, const int k, static void sin_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
sin(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
static void cos_sycl(const T *x, T *dst, const int k, static void cos_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
cos(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
@ -584,26 +536,20 @@ static void leaky_relu_sycl(const T *x, T *dst, const int k,
const float negative_slope, const float negative_slope,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
leaky_relu(x, dst, k, negative_slope, item_ct1);
});
} }
template<typename T> template<typename T>
static void sqr_sycl(const T *x, T *dst, const int k, static void sqr_sycl(const T *x, T *dst, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
sqr(x, dst, k, item_ct1);
});
} }
template<typename T> template<typename T>
@ -614,9 +560,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
int dst_size = ne10 * ne11 * ne12 * ne13; int dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
stream->parallel_for( sycl_parallel_for<1>(
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) {
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
}); });
} }
@ -627,12 +572,10 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne1, const int ne2, queue_ptr stream) { const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks); sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
});
} }
template<typename T> template<typename T>
@ -640,13 +583,10 @@ static void clamp_sycl(const T *x, T *dst, const float min,
const float max, const int k, const float max, const int k,
queue_ptr stream) { queue_ptr stream) {
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
stream->parallel_for( sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
clamp(x, dst, min, max, k, item_ct1);
});
} }
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

View File

@ -118,11 +118,9 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
GGML_ASSERT(ne00 % 2 == 0); GGML_ASSERT(ne00 % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) { k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
k_get_rows<qk, qr, dq>( item_ct1);
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
}); });
GGML_UNUSED(dst); GGML_UNUSED(dst);
@ -156,9 +154,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
}); });

View File

@ -1887,13 +1887,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const size_t shared_mem = ncols_pad * sizeof(int); const size_t shared_mem = ncols_pad * sizeof(int);
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh); sycl::range<1>(shared_mem), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>( k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1, x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>() dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@ -1901,13 +1900,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
}); });
}); });
} else if (order == GGML_SORT_ORDER_DESC) { } else if (order == GGML_SORT_ORDER_DESC) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1( sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh); sycl::range<1>(shared_mem), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>( k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1, x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>() dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@ -1925,15 +1923,13 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1); const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float); const size_t shared_mem = 256 * sizeof(float);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_data( sycl::local_accessor<float, 1> shared_data(
sycl::range<1>(shared_mem/sizeof(float)), cgh); sycl::range<1>(shared_mem/sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices( sycl::local_accessor<int, 1> shared_indices(
sycl::range<1>(shared_mem/sizeof(float)), cgh); sycl::range<1>(shared_mem/sizeof(float)), cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1); const int row = item_ct1.get_global_id(1);
@ -1952,7 +1948,7 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
shared_indices[tid] = max_idx; shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
for (int stride = 256/2; stride > 0; stride >>= 1) { for (int stride = 256 / 2; stride > 0; stride >>= 1) {
if (tid < stride) { if (tid < stride) {
float val1 = shared_data[tid]; float val1 = shared_data[tid];
float val2 = shared_data[tid + stride]; float val2 = shared_data[tid + stride];
@ -1964,7 +1960,6 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
item_ct1.barrier(sycl::access::fence_space::local_space); item_ct1.barrier(sycl::access::fence_space::local_space);
} }
if (tid == 0) { if (tid == 0) {
dst[row] = shared_indices[0]; dst[row] = shared_indices[0];
} }
@ -2952,7 +2947,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
void ** ptrs_dst_get = ptrs_dst.get(); void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
}); });
@ -3456,7 +3451,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh); sycl::local_accessor<int, 0> src1_row_acc(cgh);
char *__restrict src1_contiguous_get = char *__restrict src1_contiguous_get =
@ -3468,9 +3463,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
size_t ids_nb_ct6 = ids->nb[1]; size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0]; size_t ids_nb_ct7 = ids->nb[0];
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous( k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get, src1_original, src1_contiguous_get,
dev_cur_src1_row_get, dev_cur_src1_row_get,
@ -3501,15 +3495,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
sycl::range<3> grid_dims(1, 1, num_src1_rows); sycl::range<3> grid_dims(1, 1, num_src1_rows);
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
const char *__restrict dst_contiguous_get = const char *__restrict dst_contiguous_get =
dst_contiguous.get(); dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get = const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get(); dev_row_mapping.get();
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original, k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get, dst_contiguous_get,
dev_row_mapping_get, dev_row_mapping_get,

View File

@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
const u_int n_seq_tokens = T / B; const u_int n_seq_tokens = T / B;
sycl::range<1> block_dims((C / H)); sycl::range<1> block_dims((C / H));
sycl::range<1> grid_dims((B * H)); sycl::range<1> grid_dims((B * H));
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
/* local memory accessors*/ /* local memory accessors*/
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh); auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
u_int tid = item.get_local_id(0); u_int tid = item.get_local_id(0);
u_int bid = item.get_group(0); u_int bid = item.get_group(0);

View File

@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
const int64_t CHW = IC * KH * KW; const int64_t CHW = IC * KH * KW;
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
p0, p1, d0, d1, item_ct1); p0, p1, d0, d1, item_ct1);
}); });

View File

@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@ -1829,9 +1829,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>( mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1853,7 +1852,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@ -1864,9 +1863,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>( mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1933,7 +1931,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@ -1944,9 +1942,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>( mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -1968,7 +1965,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@ -1979,9 +1976,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>( mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2048,7 +2044,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@ -2059,9 +2055,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>( mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2083,7 +2078,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@ -2094,9 +2089,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>( mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2163,7 +2157,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@ -2174,9 +2168,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>( mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2198,7 +2191,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@ -2209,9 +2202,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>( mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2278,7 +2270,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@ -2289,9 +2281,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>( mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2313,7 +2304,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1( sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1( sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@ -2324,9 +2315,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>( mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2393,7 +2383,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@ -2406,9 +2396,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>( mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2431,7 +2420,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@ -2444,9 +2433,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>( mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2516,7 +2504,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@ -2531,9 +2519,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>( mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2557,7 +2544,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@ -2572,9 +2559,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>( mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2644,7 +2630,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@ -2657,9 +2643,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>( mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2682,7 +2667,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@ -2695,9 +2680,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>( mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2765,7 +2749,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@ -2778,9 +2762,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>( mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2803,7 +2786,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@ -2816,9 +2799,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>( mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2886,7 +2868,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@ -2899,9 +2881,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>( mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,
@ -2924,7 +2905,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1( sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@ -2937,9 +2918,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1( sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>( mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1, nrows_dst, item_ct1,

View File

@ -544,8 +544,8 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
nd_item); nd_item);
@ -561,8 +561,8 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>( mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
@ -580,15 +580,10 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -604,15 +599,10 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -628,15 +618,10 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -652,15 +637,10 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -676,15 +656,10 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -700,15 +675,10 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -724,15 +694,10 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -750,11 +715,11 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
nrows, nd_item); nd_item);
}); });
}); });
} }
@ -769,15 +734,10 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -794,8 +754,8 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows, mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item); nd_item);
@ -811,15 +771,10 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
cgh.parallel_for( mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1); vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
@ -836,13 +791,11 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nrows, item_ct1);
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -857,13 +810,11 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nrows, item_ct1);
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -878,14 +829,11 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -900,14 +848,11 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
[=](sycl::nd_item<3> item_ct1) nrows, item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -922,14 +867,11 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -944,14 +886,11 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -966,13 +905,11 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { item_ct1);
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -987,14 +924,11 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
[=](sycl::nd_item<3> item_ct1) item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }
@ -1009,14 +943,11 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{ {
sycl_launch(stream, [&](sycl::handler & cgh) {
stream->submit([&](sycl::handler &cgh) { sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
cgh.parallel_for( [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
sycl::nd_range<3>(block_nums * block_dims, block_dims), mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
[=](sycl::nd_item<3> item_ct1) nrows, item_ct1);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
}); });
}); });
} }

View File

@ -254,12 +254,11 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
GGML_ASSERT(ncols % WARP_SIZE == 0); GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { nullptr, WARP_SIZE);
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
}); });
}); });
} }
@ -272,14 +271,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1( sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
sycl::range<1>(work_group_size / WARP_SIZE), cgh); sycl::range<1>(work_group_size / WARP_SIZE), cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] { get_pointer(s_sum_acc_ct1), work_group_size);
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
} }
@ -290,16 +288,12 @@ static void group_norm_f32_sycl(const float* x, float* dst,
const int ne_elements, queue_ptr stream, int device) { const int ne_elements, queue_ptr stream, int device) {
if (group_size < 1024) { if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
const float eps_ct4 = eps; const float eps_ct4 = eps;
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
[=](sycl::nd_item<3> item_ct1) WARP_SIZE);
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);
}); });
}); });
} }
@ -313,19 +307,15 @@ static void group_norm_f32_sycl(const float* x, float* dst,
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh); cgh);
const float eps_ct4 = eps; const float eps_ct4 = eps;
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
@ -340,51 +330,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
const sycl::range<3> global_dims(nsamples, nchannels, nrows); const sycl::range<3> global_dims(nsamples, nchannels, nrows);
if (ncols < 1024) { if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE); const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
[=](sycl::nd_item<3> item_ct1) rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE); nullptr, WARP_SIZE);
}); });
}); });
@ -398,21 +347,53 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh); cgh);
cgh.parallel_for( sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
block_dims), rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size); get_pointer(s_sum_acc_ct1), work_group_size);
}); });
}); });
} }
} }
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size);
/*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
work_group_size);
});
});
}
}
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];

View File

@ -235,9 +235,10 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} else { } else {
/* /*
@ -245,9 +246,10 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
the limit. To get the device limit, query the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed. info::device::max_work_group_size. Adjust the work-group size if needed.
*/ */
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} }
} }
@ -267,14 +269,16 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} else { } else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, [=](sycl::nd_item<3> item_ct1) {
theta_scale, freq_factors, item_ct1); rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
}); });
} }
} }
@ -298,12 +302,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
} }
// launch kernel // launch kernel
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
} else { } else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
@ -333,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
} }
// launch kernel // launch kernel
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });
} else { } else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1); corr_dims, theta_scale, freq_factors, sections, item_ct1);
}); });

View File

@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
const int nrows_y, const float scale, const float max_bias, const float m0, const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, queue_ptr stream) { const size_t n_local_scratch, queue_ptr stream) {
stream->submit([&](sycl::handler &cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh); sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims), cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par, soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0, nrows_y, scale, max_bias, m0,

View File

@ -45,13 +45,8 @@ static void timestep_embedding_f32_sycl(
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE; int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE); sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
sycl::range<3> gridDim(1, ne00, num_blocks); sycl::range<3> gridDim(1, ne00, num_blocks);
stream->parallel_for( sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl::nd_range<3>( timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
gridDim * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
timestep_embedding_f32(
x, dst, nb1, dim, max_period, item_ct1
);
}); });
} }

View File

@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel // Submit kernel
if (C / H == WKV_BLOCK_SIZE) { if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>( rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
}); });
}); });
} else { } else {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>( rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel // Submit kernel
if (C / H == WKV_BLOCK_SIZE) { if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>( rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
}); });
}); });
} else { } else {
stream->submit([&](sycl::handler& cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for( sycl_parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims), cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>( rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()