diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e7304947..6977b705 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -41,6 +41,7 @@ #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" +#include "ggml.h" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -3864,7 +3865,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32); default: return false; } @@ -3981,23 +3982,24 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + return true; case GGML_OP_ADD: case GGML_OP_ADD1: - case GGML_OP_LOG: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return true; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_GROUP_NORM: - return ggml_is_contiguous(op->src[0]); - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + case GGML_OP_LOG: + return (op->src[0]->type == GGML_TYPE_F32); + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_SCALE: return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16;