mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-09 20:13:14 +00:00
vulkan: Add bfloat16 support (llama/12554)
* vulkan: Add bfloat16 support This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16. The extension is required for coopmat multiply support, but matrix-vector multiply trivially promotes bf16 to fp32 and doesn't require the extension. The copy/get_rows shaders also don't require the extension. It's probably possible to fall back to non-coopmat and promote to fp32 when the extension isn't supported, but this change doesn't do that. The coopmat support also requires a glslc that supports the extension, which currently requires a custom build. * vulkan: Support bf16 tensors without the bf16 extension or coopmat support Compile a variant of the scalar mul_mm shader that will promote the bf16 values to float, and use that when either the bf16 extension or the coopmat extensions aren't available. * vulkan: bfloat16 fixes (really works without bfloat16 support now) * vulkan: fix spirv-val failure and reenable -O
This commit is contained in:
parent
df458380d6
commit
cde0e50536
@ -71,6 +71,22 @@ if (Vulkan_FOUND)
|
|||||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
|
||||||
|
# If it's not, there will be an error to stderr.
|
||||||
|
# If it's supported, set a define to indicate that we should compile those shaders
|
||||||
|
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||||
|
OUTPUT_VARIABLE glslc_output
|
||||||
|
ERROR_VARIABLE glslc_error)
|
||||||
|
|
||||||
|
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
|
||||||
|
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
|
||||||
|
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
|
||||||
|
else()
|
||||||
|
message(STATUS "GL_EXT_bfloat16 supported by glslc")
|
||||||
|
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
|
||||||
|
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
endif()
|
||||||
|
|
||||||
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
|
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
|
||||||
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
||||||
@ -142,6 +158,7 @@ if (Vulkan_FOUND)
|
|||||||
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
|
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
|
||||||
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
|
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
|
||||||
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
|
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
|
||||||
|
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
|
||||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||||
|
@ -51,6 +51,24 @@
|
|||||||
|
|
||||||
#include "ggml-vulkan-shaders.hpp"
|
#include "ggml-vulkan-shaders.hpp"
|
||||||
|
|
||||||
|
// remove this once it's more widely available in the SDK
|
||||||
|
#if !defined(VK_KHR_shader_bfloat16)
|
||||||
|
|
||||||
|
#define VK_KHR_shader_bfloat16 1
|
||||||
|
#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
|
||||||
|
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
|
||||||
|
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
|
||||||
|
#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
|
||||||
|
|
||||||
|
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
||||||
|
VkStructureType sType;
|
||||||
|
void* pNext;
|
||||||
|
VkBool32 shaderBFloat16Type;
|
||||||
|
VkBool32 shaderBFloat16DotProduct;
|
||||||
|
VkBool32 shaderBFloat16CooperativeMatrix;
|
||||||
|
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
||||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||||
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||||
@ -266,8 +284,9 @@ struct vk_device_struct {
|
|||||||
bool subgroup_require_full_support;
|
bool subgroup_require_full_support;
|
||||||
|
|
||||||
bool coopmat_support;
|
bool coopmat_support;
|
||||||
bool coopmat_acc_f32_support;
|
bool coopmat_acc_f32_support {};
|
||||||
bool coopmat_acc_f16_support;
|
bool coopmat_acc_f16_support {};
|
||||||
|
bool coopmat_bf16_support {};
|
||||||
uint32_t coopmat_m;
|
uint32_t coopmat_m;
|
||||||
uint32_t coopmat_n;
|
uint32_t coopmat_n;
|
||||||
uint32_t coopmat_k;
|
uint32_t coopmat_k;
|
||||||
@ -293,6 +312,7 @@ struct vk_device_struct {
|
|||||||
|
|
||||||
vk_matmul_pipeline pipeline_matmul_f32 {};
|
vk_matmul_pipeline pipeline_matmul_f32 {};
|
||||||
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
||||||
|
vk_matmul_pipeline pipeline_matmul_bf16 {};
|
||||||
vk_matmul_pipeline2 pipeline_matmul_f16;
|
vk_matmul_pipeline2 pipeline_matmul_f16;
|
||||||
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
||||||
|
|
||||||
@ -301,6 +321,7 @@ struct vk_device_struct {
|
|||||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
||||||
|
|
||||||
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
||||||
|
vk_matmul_pipeline pipeline_matmul_id_bf16 {};
|
||||||
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
||||||
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
||||||
|
|
||||||
@ -333,8 +354,8 @@ struct vk_device_struct {
|
|||||||
vk_pipeline pipeline_clamp_f32;
|
vk_pipeline pipeline_clamp_f32;
|
||||||
vk_pipeline pipeline_pad_f32;
|
vk_pipeline pipeline_pad_f32;
|
||||||
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
||||||
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
|
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
|
||||||
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
|
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
|
||||||
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_norm_f32;
|
vk_pipeline pipeline_norm_f32;
|
||||||
@ -1811,6 +1832,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
if (!device->pipeline_matmul_id_f32) {
|
if (!device->pipeline_matmul_id_f32) {
|
||||||
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
}
|
}
|
||||||
|
if (!device->pipeline_matmul_bf16) {
|
||||||
|
device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
|
}
|
||||||
|
if (!device->pipeline_matmul_id_bf16) {
|
||||||
|
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::future<void>> compiles;
|
std::vector<std::future<void>> compiles;
|
||||||
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
||||||
@ -1920,6 +1947,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||||
|
|
||||||
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
@ -1941,6 +1973,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||||
|
|
||||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||||
@ -1994,6 +2031,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
@ -2042,6 +2084,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
if (device->coopmat_acc_f16_support) {
|
if (device->coopmat_acc_f16_support) {
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
@ -2124,6 +2171,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
@ -2159,6 +2208,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
@ -2211,6 +2262,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||||
@ -2246,6 +2299,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
|
||||||
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
@ -2266,8 +2321,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
#undef CREATE_MM
|
|
||||||
}
|
}
|
||||||
|
// reusing CREATE_MM from the fp32 path
|
||||||
|
if ((device->coopmat2 || device->coopmat_support)
|
||||||
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
|
&& !device->coopmat_bf16_support
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
|
// use scalar tile sizes
|
||||||
|
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||||
|
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
||||||
|
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||||
|
|
||||||
|
l_wg_denoms = {128, 128, 1 };
|
||||||
|
m_wg_denoms = { 64, 64, 1 };
|
||||||
|
s_wg_denoms = { 32, 32, 1 };
|
||||||
|
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||||
|
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
||||||
|
}
|
||||||
|
#undef CREATE_MM
|
||||||
|
|
||||||
// mul mat vec
|
// mul mat vec
|
||||||
|
|
||||||
@ -2286,6 +2359,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
@ -2308,6 +2382,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
@ -2331,6 +2406,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
||||||
@ -2376,6 +2452,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
// get_rows
|
// get_rows
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
@ -2393,6 +2470,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||||
@ -2430,10 +2508,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
if (device->float_controls_rte_fp16) {
|
if (device->float_controls_rte_fp16) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
||||||
@ -2601,6 +2682,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
bool coopmat2_support = false;
|
bool coopmat2_support = false;
|
||||||
device->coopmat_support = false;
|
device->coopmat_support = false;
|
||||||
device->integer_dot_product = false;
|
device->integer_dot_product = false;
|
||||||
|
bool bfloat16_support = false;
|
||||||
|
|
||||||
for (const auto& properties : ext_props) {
|
for (const auto& properties : ext_props) {
|
||||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||||
@ -2631,6 +2713,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||||
device->integer_dot_product = true;
|
device->integer_dot_product = true;
|
||||||
#endif
|
#endif
|
||||||
|
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
||||||
|
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||||
|
bfloat16_support = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2817,6 +2902,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(VK_KHR_shader_bfloat16)
|
||||||
|
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
||||||
|
bfloat16_features.pNext = nullptr;
|
||||||
|
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
|
||||||
|
if (bfloat16_support) {
|
||||||
|
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
|
||||||
|
last_struct = (VkBaseOutStructure *)&bfloat16_features;
|
||||||
|
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
||||||
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
||||||
if (maintenance4_support) {
|
if (maintenance4_support) {
|
||||||
@ -3014,6 +3110,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
device->coopmat_int_n = prop.NSize;
|
device->coopmat_int_n = prop.NSize;
|
||||||
device->coopmat_int_k = prop.KSize;
|
device->coopmat_int_k = prop.KSize;
|
||||||
}
|
}
|
||||||
|
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||||
|
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||||
|
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||||
|
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
||||||
|
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
||||||
|
) {
|
||||||
|
// coopmat sizes not set yet
|
||||||
|
if (device->coopmat_m == 0) {
|
||||||
|
device->coopmat_bf16_support = true;
|
||||||
|
device->coopmat_m = prop.MSize;
|
||||||
|
device->coopmat_n = prop.NSize;
|
||||||
|
device->coopmat_k = prop.KSize;
|
||||||
|
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
||||||
|
// Only enable if shape is identical
|
||||||
|
device->coopmat_bf16_support = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
||||||
@ -3021,11 +3136,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
||||||
device->coopmat_support = false;
|
device->coopmat_support = false;
|
||||||
}
|
}
|
||||||
|
if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||||
|
device->coopmat_bf16_support = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device->coopmat_support) {
|
if (device->coopmat_support) {
|
||||||
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
||||||
}
|
}
|
||||||
|
#if defined(VK_KHR_shader_bfloat16)
|
||||||
|
if (device->coopmat_bf16_support) {
|
||||||
|
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
device->name = GGML_VK_NAME + std::to_string(idx);
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
||||||
|
|
||||||
@ -3482,6 +3605,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
||||||
return ctx->device->pipeline_matmul_f32_f16;
|
return ctx->device->pipeline_matmul_f32_f16;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
||||||
|
return ctx->device->pipeline_matmul_bf16;
|
||||||
|
}
|
||||||
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
||||||
@ -3553,6 +3679,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|||||||
switch (a_type) {
|
switch (a_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
@ -3585,6 +3712,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|||||||
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_id_f32;
|
return ctx->device->pipeline_matmul_id_f32;
|
||||||
}
|
}
|
||||||
|
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
||||||
|
return ctx->device->pipeline_matmul_id_bf16;
|
||||||
|
}
|
||||||
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
||||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
||||||
@ -3638,6 +3768,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
|||||||
switch (a_type) {
|
switch (a_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
@ -4373,6 +4504,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|||||||
return ctx->device->pipeline_cpy_f16_f16;
|
return ctx->device->pipeline_cpy_f16_f16;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
|
||||||
|
if (contig) {
|
||||||
|
return ctx->device->pipeline_contig_cpy_f32_bf16;
|
||||||
|
} else {
|
||||||
|
return ctx->device->pipeline_cpy_f32_bf16;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (src->type == GGML_TYPE_F32) {
|
if (src->type == GGML_TYPE_F32) {
|
||||||
switch (to) {
|
switch (to) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
@ -4500,8 +4638,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||||
!ggml_vk_dim01_contiguous(src0);
|
!ggml_vk_dim01_contiguous(src0);
|
||||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||||
|
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
||||||
!ggml_vk_dim01_contiguous(src1);
|
!ggml_vk_dim01_contiguous(src1);
|
||||||
|
|
||||||
|
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
||||||
|
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
||||||
|
|
||||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||||
|
|
||||||
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
||||||
@ -4511,25 +4653,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
|
|
||||||
if (mmp == nullptr) {
|
if (mmp == nullptr) {
|
||||||
// Fall back to f16 dequant mul mat
|
// Fall back to f16 dequant mul mat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||||
quantize_y = false;
|
quantize_y = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||||
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
|
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
|
||||||
|
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
// Fall back to dequant + f16 mulmat
|
// Fall back to dequant + f16 mulmat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not implemented
|
// Not implemented
|
||||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||||
|
|
||||||
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
||||||
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
||||||
|
|
||||||
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
||||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
||||||
@ -4550,12 +4692,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||||||
vk_pipeline to_q8_1 = nullptr;
|
vk_pipeline to_q8_1 = nullptr;
|
||||||
|
|
||||||
if (x_non_contig) {
|
if (x_non_contig) {
|
||||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
||||||
}
|
}
|
||||||
if (y_non_contig) {
|
if (y_non_contig) {
|
||||||
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||||
}
|
}
|
||||||
@ -5055,7 +5197,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||||||
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||||
// when ne12 and ne13 are one.
|
// when ne12 and ne13 are one.
|
||||||
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
||||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
|
||||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
} else {
|
} else {
|
||||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
@ -5123,27 +5265,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||||
!ggml_vk_dim01_contiguous(src0);
|
!ggml_vk_dim01_contiguous(src0);
|
||||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||||
|
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
||||||
!ggml_vk_dim01_contiguous(src1);
|
!ggml_vk_dim01_contiguous(src1);
|
||||||
|
|
||||||
|
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
||||||
|
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
||||||
|
|
||||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||||
|
|
||||||
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
||||||
|
|
||||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
|
||||||
|
|
||||||
if (qx_needs_dequant) {
|
if (qx_needs_dequant) {
|
||||||
// Fall back to dequant + f16 mulmat
|
// Fall back to dequant + f16 mulmat
|
||||||
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not implemented
|
// Not implemented
|
||||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||||
|
|
||||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
||||||
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
||||||
|
|
||||||
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
||||||
|
|
||||||
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
||||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
||||||
@ -5162,12 +5308,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||||||
vk_pipeline to_fp16_vk_1 = nullptr;
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
||||||
|
|
||||||
if (x_non_contig) {
|
if (x_non_contig) {
|
||||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
||||||
}
|
}
|
||||||
if (y_non_contig) {
|
if (y_non_contig) {
|
||||||
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
|
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||||
}
|
}
|
||||||
@ -9295,6 +9441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||||||
switch (src0_type) {
|
switch (src0_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
@ -9330,10 +9477,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||||||
if (a->ne[3] != b->ne[3]) {
|
if (a->ne[3] != b->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
|
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
|
||||||
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
|
||||||
|
// We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
|
||||||
|
// So don't support this combination for now.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
} break;
|
} break;
|
||||||
@ -9406,6 +9558,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||||||
switch (op->src[0]->type) {
|
switch (op->src[0]->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
@ -9436,6 +9589,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||||||
switch (src1_type) {
|
switch (src1_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
|
case GGML_TYPE_BF16:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
|
@ -12,6 +12,9 @@ endif()
|
|||||||
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
endif()
|
endif()
|
||||||
|
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
endif()
|
||||||
set(TARGET vulkan-shaders-gen)
|
set(TARGET vulkan-shaders-gen)
|
||||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
@ -18,7 +18,11 @@ void main() {
|
|||||||
// fast path for when all four iterations are in-bounds
|
// fast path for when all four iterations are in-bounds
|
||||||
if (idx + (num_iter-1)*num_threads < p.ne) {
|
if (idx + (num_iter-1)*num_threads < p.ne) {
|
||||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
|
||||||
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + idx]);
|
||||||
|
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
@ -31,7 +35,10 @@ void main() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + idx]);
|
||||||
|
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
|
@ -12,7 +12,10 @@ void main() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_D_BF16)
|
||||||
|
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
|
||||||
|
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
|
||||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
#else
|
#else
|
||||||
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
||||||
|
@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
|
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_0)
|
#if defined(DATA_A_Q4_0)
|
||||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||||
vec2 get_dm(uint ib, uint a_offset) {
|
vec2 get_dm(uint ib, uint a_offset) {
|
||||||
return vec2(0, 0);
|
return vec2(0, 0);
|
||||||
}
|
}
|
||||||
|
@ -20,9 +20,14 @@ void main() {
|
|||||||
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||||
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#if defined(DATA_A_BF16)
|
||||||
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
|
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
|
||||||
#else
|
#else
|
||||||
data_d[d_offset + i00] = data_a[a_offset + i00];
|
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
|
||||||
|
#endif
|
||||||
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
|
data_d[d_offset + i00] = D_TYPE(v);
|
||||||
|
#else
|
||||||
|
data_d[d_offset + i00] = D_TYPE(v);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
|
||||||
#define K_PER_ITER 8
|
#define K_PER_ITER 8
|
||||||
#else
|
#else
|
||||||
#define K_PER_ITER 2
|
#define K_PER_ITER 2
|
||||||
|
@ -10,6 +10,10 @@
|
|||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16) && defined(COOPMAT)
|
||||||
|
#extension GL_EXT_bfloat16 : enable
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
#extension GL_KHR_cooperative_matrix : enable
|
#extension GL_KHR_cooperative_matrix : enable
|
||||||
#extension GL_KHR_memory_scope_semantics : enable
|
#extension GL_KHR_memory_scope_semantics : enable
|
||||||
@ -29,6 +33,10 @@
|
|||||||
#define LOAD_VEC_B 1
|
#define LOAD_VEC_B 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if !defined(TO_FLOAT_TYPE)
|
||||||
|
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||||
|
#endif
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
@ -202,8 +210,8 @@ void main() {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||||
@ -248,6 +256,21 @@ void main() {
|
|||||||
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
#elif defined(DATA_A_BF16)
|
||||||
|
#if LOAD_VEC_A == 4
|
||||||
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||||
|
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
|
||||||
|
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
|
||||||
|
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
|
||||||
|
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
|
||||||
|
#else
|
||||||
|
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
|
||||||
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
|
||||||
|
} else {
|
||||||
|
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#elif defined(DATA_A_Q4_0)
|
#elif defined(DATA_A_Q4_0)
|
||||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
|
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
|
||||||
@ -695,13 +718,13 @@ void main() {
|
|||||||
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||||
#endif
|
#endif
|
||||||
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
|
||||||
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
|
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
|
||||||
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
|
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
|
||||||
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
|
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
|
||||||
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
|
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
|
||||||
#elif !MUL_MAT_ID
|
#elif !MUL_MAT_ID
|
||||||
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
@ -709,7 +732,7 @@ void main() {
|
|||||||
const uint row_i = ic * BN + loadc_b + l;
|
const uint row_i = ic * BN + loadc_b + l;
|
||||||
if (row_i < _ne1) {
|
if (row_i < _ne1) {
|
||||||
const u16vec2 row_idx = row_ids[row_i];
|
const u16vec2 row_idx = row_ids[row_i];
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
|
||||||
} else {
|
} else {
|
||||||
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,9 @@
|
|||||||
#extension GL_EXT_buffer_reference : enable
|
#extension GL_EXT_buffer_reference : enable
|
||||||
#extension GL_KHR_shader_subgroup_ballot : enable
|
#extension GL_KHR_shader_subgroup_ballot : enable
|
||||||
#extension GL_KHR_shader_subgroup_vote : enable
|
#extension GL_KHR_shader_subgroup_vote : enable
|
||||||
|
#ifdef DATA_A_BF16
|
||||||
|
#extension GL_EXT_bfloat16 : enable
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
|
||||||
@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
|||||||
#define store_scales(a)
|
#define store_scales(a)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
#define MAT_TYPE bfloat16_t
|
||||||
|
#else
|
||||||
|
#define MAT_TYPE FLOAT_TYPE
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||||
|
|
||||||
@ -271,8 +280,8 @@ void main() {
|
|||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||||
@ -286,8 +295,8 @@ void main() {
|
|||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
|
||||||
@ -310,8 +319,8 @@ void main() {
|
|||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||||
@ -325,8 +334,8 @@ void main() {
|
|||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
|
||||||
@ -350,8 +359,8 @@ void main() {
|
|||||||
|
|
||||||
// Manually partial unroll
|
// Manually partial unroll
|
||||||
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||||
@ -365,8 +374,8 @@ void main() {
|
|||||||
store_scales(tid);
|
store_scales(tid);
|
||||||
}
|
}
|
||||||
while (block_k < end_k) {
|
while (block_k < end_k) {
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||||
@ -405,8 +414,8 @@ void main() {
|
|||||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||||
|
|
||||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
#version 460
|
||||||
|
|
||||||
|
#extension GL_EXT_bfloat16 : require
|
||||||
|
|
||||||
|
void main()
|
||||||
|
{
|
||||||
|
}
|
@ -33,6 +33,19 @@
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(DATA_A_BF16)
|
||||||
|
#define QUANT_K 1
|
||||||
|
#define QUANT_R 1
|
||||||
|
|
||||||
|
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
|
||||||
|
#define A_TYPE uint16_t
|
||||||
|
#elif LOAD_VEC_A == 4
|
||||||
|
#define A_TYPE u16vec4
|
||||||
|
#elif LOAD_VEC_A == 8
|
||||||
|
#error unsupported
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
#define QUANT_K_Q4_0 32
|
#define QUANT_K_Q4_0 32
|
||||||
#define QUANT_R_Q4_0 2
|
#define QUANT_R_Q4_0 2
|
||||||
|
|
||||||
@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// returns the bfloat value in the low 16b.
|
||||||
|
// See ggml_compute_fp32_to_bf16
|
||||||
|
uint32_t fp32_to_bf16(float f)
|
||||||
|
{
|
||||||
|
uint32_t u = floatBitsToUint(f);
|
||||||
|
u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
|
||||||
|
float bf16_to_fp32(uint32_t u)
|
||||||
|
{
|
||||||
|
return uintBitsToFloat(u << 16);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // !defined(GGML_TYPES_COMP)
|
#endif // !defined(GGML_TYPES_COMP)
|
||||||
|
@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
|
|||||||
"iq3_xxs",
|
"iq3_xxs",
|
||||||
"iq3_s",
|
"iq3_s",
|
||||||
"iq4_xs",
|
"iq4_xs",
|
||||||
"iq4_nl"
|
"iq4_nl",
|
||||||
|
"bf16",
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|||||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||||
|
|
||||||
std::map<std::string, std::string> base_dict = {
|
std::map<std::string, std::string> base_dict = {
|
||||||
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
|
|
||||||
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
||||||
};
|
};
|
||||||
std::string shader_name = "matmul";
|
std::string shader_name = "matmul";
|
||||||
@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|||||||
|
|
||||||
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
||||||
|
|
||||||
// Shaders with f16 B_TYPE
|
auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
|
||||||
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
if (t == "bf16") {
|
||||||
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
// scalar path promotes to float
|
||||||
|
if (!coopmat && !coopmat2) {
|
||||||
|
return "float";
|
||||||
|
}
|
||||||
|
return "bfloat16_t";
|
||||||
|
}
|
||||||
|
if (coopmat2 || fp16) {
|
||||||
|
return "float16_t";
|
||||||
|
}
|
||||||
|
return "float";
|
||||||
|
};
|
||||||
|
|
||||||
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
// Shaders with f16 B_TYPE
|
||||||
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
|
||||||
|
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
|
||||||
|
// bf16
|
||||||
|
{
|
||||||
|
std::string load_vec_a_unaligned = "1";
|
||||||
|
// For aligned matmul loads
|
||||||
|
std::string load_vec_a = coopmat2 ? "1" : "4";
|
||||||
|
|
||||||
|
// scalar path promotes to float
|
||||||
|
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
|
||||||
|
|
||||||
|
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
|
||||||
|
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||||
|
if (!(coopmat || coopmat2))
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto& tname : type_names) {
|
for (const auto& tname : type_names) {
|
||||||
std::string load_vec_quant = "2";
|
std::string load_vec_quant = "2";
|
||||||
@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
|||||||
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
|
||||||
load_vec_quant = "4";
|
load_vec_quant = "4";
|
||||||
|
|
||||||
|
if (tname == "bf16") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
|
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||||
// For aligned matmul loads
|
// For aligned matmul loads
|
||||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
|
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||||
|
|
||||||
// don't generate f32 variants for coopmat2
|
// don't generate f32 variants for coopmat2
|
||||||
if (!coopmat2) {
|
if (!coopmat2) {
|
||||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tname != "f16" && tname != "f32") {
|
if (tname != "f16" && tname != "f32") {
|
||||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@ -393,6 +430,7 @@ void process_shaders() {
|
|||||||
if (tname == "f32") {
|
if (tname == "f32") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (tname == "bf16") continue;
|
||||||
|
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||||
@ -417,12 +455,12 @@ void process_shaders() {
|
|||||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
// Dequant shaders
|
// Dequant shaders
|
||||||
if (tname != "f16") {
|
if (tname != "f16" && tname != "bf16") {
|
||||||
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!string_ends_with(tname, "_k")) {
|
if (!string_ends_with(tname, "_k")) {
|
||||||
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
|
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
|
||||||
|
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
||||||
@ -447,9 +485,11 @@ void process_shaders() {
|
|||||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||||
|
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||||
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
||||||
|
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
|
||||||
|
|
||||||
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
|
||||||
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user