diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 6cbc7e0f..5efe03d3 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -65,6 +65,9 @@ public: dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#ifdef GGML_SYCL_F16 + primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); +#endif auto a_mem = dnnl::memory(a_in_md, eng, const_cast(a)); auto b_mem = dnnl::memory(b_in_md, eng, const_cast(b)); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3693b0a4..feb30304 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 ? (const sycl::half *)src1->data + src1_padded_row_size : src1_as_f16.get(); - ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), - dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); - scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, - " : converting dst to fp32"); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); } else #endif { + ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); + const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(