From 5b854ebba53bdff32961e50a60548c85e680ad7d Mon Sep 17 00:00:00 2001
From: Svetlozar Georgiev <55534064+sgeor255@users.noreply.github.com>
Date: Fri, 21 Mar 2025 02:15:56 +0000
Subject: [PATCH] sycl: cleanup oneDNN related code (llama/12097)

---
 ggml/src/ggml-sycl/CMakeLists.txt | 44 ++++++++++++++++-------
 ggml/src/ggml-sycl/common.hpp     | 28 +++++++++++++++-
 ggml/src/ggml-sycl/gemm.hpp       | 59 ++++++++-----------------------
 ggml/src/ggml-sycl/ggml-sycl.cpp  | 12 +++---
 4 files changed, 79 insertions(+), 64 deletions(-)

diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt
index 271413ca..f713fbe4 100644
--- a/ggml/src/ggml-sycl/CMakeLists.txt
+++ b/ggml/src/ggml-sycl/CMakeLists.txt
@@ -23,6 +23,38 @@ ggml_add_backend_library(ggml-sycl
                          ../../include/ggml-sycl.h
                         )
 
+find_package(DNNL)
+set(GGML_SYCL_DNNL 0)
+if(DNNL_FOUND)
+    if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
+        # Assuming oneDNN packaged with oneapi release is used which
+        # supports only intel target
+        set(DNNL_GPU_VENDOR "INTEL")
+        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+        endif()
+    endif()
+
+    # Verify oneDNN was compiled for the same target as llama
+    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+        set(GGML_SYCL_DNNL 1)
+        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+        foreach(CONFIG ${CONFIGS})
+            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+            message(STATUS "Found oneDNN: ${DNNL_LIB}")
+        endforeach()
+    else()
+        message(WARNING
+            "oneDNN must be compiled for the same target as llama.cpp.
+            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+            Disabling oneDNN support.")
+    endif()
+else()
+    message(STATUS "oneDNN not found, disabling oneDNN support")
+endif()
+target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
+
 if (GGML_SYCL_F16)
     if (GGML_SYCL_TARGET STREQUAL "AMD")
         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -48,18 +80,6 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
-find_package(DNNL)
-message("-- DNNL found:" ${DNNL_FOUND})
-
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
-else()
-    add_compile_definitions(GGML_SYCL_DNNL=0)
-endif()
-
-if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
-    target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-endif()
 
 if (WIN32)
     find_package(IntelSYCL REQUIRED)
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 7cc5e14f..27b447ce 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();
 
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
-
     int current_device_id;
     SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
 
@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
         }
     }
 
+    T * realloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        if (ptr)
+            pool->free(ptr, actual_size);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(pool != nullptr);
@@ -371,10 +378,29 @@ struct ggml_backend_sycl_context {
     dnnl::stream stream_dnnl() {
         return stream_dnnl(device, 0);
     }
+
+    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+                                    const dnnl::engine & eng, const queue_ptr q) {
+        ggml_sycl_pool_alloc<uint8_t> * pool;
+        auto it = scratchpad_map.find(q);
+        if (it == scratchpad_map.end()) {
+            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+            pool = scratchpad_map[q].get();
+        } else {
+            pool = it->second.get();
+        }
+
+        size_t scratchpad_size = scratchpad_md.get_size();
+        if (scratchpad_size > pool->actual_size) {
+            pool->realloc(scratchpad_size);
+        }
+        void * mem_ptr = pool->get();
+        return dnnl::memory(scratchpad_md, eng, mem_ptr);
+    }
 #endif
 
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index 3f0f34ad..4ebbb5b6 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -13,9 +13,6 @@
 #ifndef GGML_SYCL_GEMM_HPP
 #define GGML_SYCL_GEMM_HPP
 
-#include <fstream>
-#include <iostream>
-
 #include "ggml-sycl.h"
 
 #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ public:
         else static_assert(0);
     }
 
-    static inline void row_gemm(sycl::queue& q, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        // Get the device associated with the queue
-        sycl::device dev = q.get_device();
-        // Get the context associated with the queue
-        sycl::context ctx = q.get_context();
-        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
-        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
+                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+        auto stream = ctx.stream_dnnl(q);
+        auto eng = ctx.engine_dnnl(q);
         dnnl::memory::dims a_dims = { m, k };
         dnnl::memory::dims b_dims = { k, n };
         dnnl::memory::dims c_dims = { m, n };
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        dnnl::primitive_attr primitive_attr;
+        primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
         auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void *>(a));
         auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void *>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
 
-        // Create the primitive.
+        auto scratchpad_md = matmul_pd.scratchpad_desc();
+        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
-        // Primitive arguments.
-        std::unordered_map<int, dnnl::memory> matmul_args;
-        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
-        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
-        matmul_args.insert({ DNNL_ARG_DST, c_mem });
-
-        matmul_prim.execute(stream, matmul_args);
-    }
-
-
-    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        auto const eng = stream.get_engine();
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
-        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void *>(a));
-        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void *>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
-        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
-
-        // Create the primitive.
-        auto matmul_prim = dnnl::matmul(matmul_pd);
-        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 360e3f16..f4b68333 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2058,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
             const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
             to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
 #endif
@@ -2099,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dst_dd_i, ldc)));
 # endif
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 #endif
     }
     GGML_UNUSED(dst);