mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-22 16:38:58 +00:00
SYCL: Rename oneMKL to oneMath (llama/12192)
* Rename oneMKL Interface to oneMath * Use oneMath for Intel vendor * Rename occurences to mkl * clang-format * Silence verbose warnings * Set oneMath HIP_TARGETS * Fix silence warnings * Remove step to build oneMath from build instructions * Use fixed oneMath version * Remove INTEL_CPU * Fold CMake oneDNN conditions * Use Intel oneMKL for Intel devices * Improve CMake message * Link against MKL::MKL_SYCL::BLAS only * Move oneMath documentation to Nvidia and AMD sections
This commit is contained in:
committed by
Georgi Gerganov
parent
0d42097fd3
commit
ddf7e6a15d
@ -2059,8 +2059,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||
*stream, oneapi::mkl::transpose::trans,
|
||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
*stream, oneapi::math::transpose::trans,
|
||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||
@ -2097,17 +2097,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||
#if !GGML_SYCL_DNNL
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
# ifdef GGML_SYCL_NVIDIA
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||
oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
|
||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
|
||||
ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
# else
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
||||
dst_dd_i, ldc)));
|
||||
# endif
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||
#else
|
||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||
@ -2836,14 +2829,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*main_stream, oneapi::mkl::transpose::trans,
|
||||
oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const char *)src0_as_f16, dpct::library_data_t::real_half,
|
||||
nb01 / nb00, nb02 / nb00,
|
||||
(const char *)src1_f16, dpct::library_data_t::real_half,
|
||||
nb11 / nb10, nb12 / nb10, beta,
|
||||
(char *)dst_t, cu_data_type, ne01, nb2 / nb0,
|
||||
ne12 * ne13, cu_compute_type)));
|
||||
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||
(const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
|
||||
cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
|
||||
} else {
|
||||
const int ne23 = ne12*ne13;
|
||||
|
||||
@ -2878,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
||||
});
|
||||
}
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||
*main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
*main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
|
||||
(void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
|
||||
|
Reference in New Issue
Block a user