SYCL: Refactor and enable FP16 in binary broadcast OPs (llama/12975)

* SYCL: refactor move to a separate file

* Fix binbcast

* Remove duplicates

* fix include formatting

* fix typo
This commit is contained in:
Akarshan Biswas
2025-04-18 19:27:56 +05:30
committed by Georgi Gerganov
parent 24d29c55df
commit 0287a5c51b
7 changed files with 393 additions and 372 deletions

View File

@ -1967,11 +1967,6 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
}
inline void ggml_sycl_op_mul_mat_sycl(
ggml_backend_sycl_context & ctx,
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@ -2600,12 +2595,6 @@ catch (sycl::exception const &exc) {
}
static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_repeat(ctx, dst);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_get_rows(ctx, dst);
@ -3972,7 +3961,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ARGMAX:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_REPEAT:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
@ -3982,7 +3970,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return (op->src[0]->type == GGML_TYPE_F32);
case GGML_OP_REPEAT:
return true;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN: