diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index e9500f3a..abad847c 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -333,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -345,6 +349,15 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device, 0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { return host_pool(device); } }; // common device functions diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index e167948e..c96395be 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -82,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1727,26 +1735,13 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1763,23 +1758,18 @@ namespace dpct sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast(a), - matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), matrix_info->ld_info + 2, 1, - &(matrix_info->groupsize_info)); + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #else sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, - matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), - matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); #endif - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); } template @@ -2422,25 +2412,11 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2449,48 +2425,24 @@ namespace dpct library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2498,19 +2450,16 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2522,10 +2471,9 @@ namespace dpct dpct::get_value(reinterpret_cast(alpha), q); float beta_float = dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); break; } case detail::get_type_combination_id( @@ -2533,8 +2481,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2542,8 +2489,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2557,8 +2503,7 @@ namespace dpct sycl::half alpha_half(alpha_value); sycl::half beta_half(beta_value); detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 5272ca45..ed4d8bb8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1173,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +struct ggml_sycl_pool_host : public ggml_sycl_pool { + queue_ptr qptr; + int device; + + inline static int counter{ 0 }; + + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{ 64 }; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer & b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { // TBD: NO VMM support // if (ggml_sycl_info().devices[device].vmm) { @@ -3363,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3391,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **)(ptrs_src.get() + 0 * ne23), - dpct::library_data_t::real_half, nb01 / nb00, - (const void **)(ptrs_src.get() + 1 * ne23), - dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, + (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) {