whisper : initial hipBLAS support (#1209)

This commit is contained in:
ardfork 2023-08-27 17:03:58 +00:00 committed by GitHub
parent b5bb5c85d4
commit cb5fb0a12d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 0 deletions

View File

@ -65,6 +65,7 @@ else()
option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic) option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic)
option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF) option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF)
option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF) option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF)
option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF)
option(WHISPER_CLBLAST "whisper: use CLBlast" OFF) option(WHISPER_CLBLAST "whisper: use CLBlast" OFF)
endif() endif()
@ -191,6 +192,37 @@ if (WHISPER_CUBLAS)
endif() endif()
endif() endif()
if (WHISPER_HIPBLAS)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
find_package(hip)
find_package(hipblas)
find_package(rocblas)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON)
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
if (WHISPER_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
endif()
endif()
if (WHISPER_CLBLAST) if (WHISPER_CLBLAST)
find_package(CLBlast) find_package(CLBlast)
if (CLBlast_FOUND) if (CLBlast_FOUND)

View File

@ -161,6 +161,21 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif endif
ifdef WHISPER_HIPBLAS
ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
LDFLAGS += -lhipblas -lamdhip64 -lrocblas
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
WHISPER_OBJ += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif
ifdef WHISPER_CLBLAST ifdef WHISPER_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST CXXFLAGS += -DGGML_USE_CLBLAST

View File

@ -6,9 +6,60 @@
#include <atomic> #include <atomic>
#include <assert.h> #include <assert.h>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasCreate hipblasCreate
#define cublasGetStatusString rocblas_status_to_string
#define cublasHandle_t hipblasHandle_t
#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDestroy hipEventDestroy
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEvent_t hipEvent_t
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaMalloc hipMalloc
#define cudaMallocHost hipHostMalloc
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#else
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#endif
#include "ggml-cuda.h" #include "ggml-cuda.h"
#include "ggml.h" #include "ggml.h"