Refactor validation and enumeration platform checks into functions to clean up ggml_vk_instance_init()

This commit is contained in:
0cc4m 2024-02-14 20:57:17 +01:00 committed by Georgi Gerganov
parent 9fca69b410
commit 8daa534818
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1091,7 +1091,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} }
} }
static void ggml_vk_instance_init() { static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
void ggml_vk_instance_init() {
if (vk_instance_initialized) { if (vk_instance_initialized) {
return; return;
} }
@ -1102,54 +1105,40 @@ static void ggml_vk_instance_init() {
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties(); const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
#ifdef __APPLE__ const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
bool portability_enumeration_ext = false; const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
// Check for portability enumeration extension for MoltenVK support
for (const auto& properties : instance_extensions) {
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
portability_enumeration_ext = true;
break;
}
}
if (!portability_enumeration_ext) {
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
}
#endif
std::vector<const char*> layers = { std::vector<const char*> layers;
#ifdef GGML_VULKAN_VALIDATE
"VK_LAYER_KHRONOS_validation", if (validation_ext) {
#endif layers.push_back("VK_LAYER_KHRONOS_validation");
}; }
std::vector<const char*> extensions = { std::vector<const char*> extensions;
#ifdef GGML_VULKAN_VALIDATE if (validation_ext) {
"VK_EXT_validation_features", extensions.push_back("VK_EXT_validation_features");
#endif }
};
#ifdef __APPLE__
if (portability_enumeration_ext) { if (portability_enumeration_ext) {
extensions.push_back("VK_KHR_portability_enumeration"); extensions.push_back("VK_KHR_portability_enumeration");
} }
#endif
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
#ifdef __APPLE__
if (portability_enumeration_ext) { if (portability_enumeration_ext) {
instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
} }
#endif
std::vector<vk::ValidationFeatureEnableEXT> features_enable;
vk::ValidationFeaturesEXT validation_features;
#ifdef GGML_VULKAN_VALIDATE if (validation_ext) {
const std::vector<vk::ValidationFeatureEnableEXT> features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices };
vk::ValidationFeaturesEXT validation_features = { validation_features = {
features_enable, features_enable,
{}, {},
}; };
validation_features.setPNext(nullptr); validation_features.setPNext(nullptr);
instance_create_info.setPNext(&validation_features); instance_create_info.setPNext(&validation_features);
std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
#endif }
vk_instance.instance = vk::createInstance(instance_create_info); vk_instance.instance = vk::createInstance(instance_create_info);
memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES); memset(vk_instance.initialized, 0, sizeof(bool) * GGML_VK_MAX_DEVICES);
@ -5329,6 +5318,42 @@ GGML_CALL int ggml_backend_vk_reg_devices() {
return vk_instance.device_indices.size(); return vk_instance.device_indices.size();
} }
// Extension availability
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
#ifdef GGML_VULKAN_VALIDATE
bool portability_enumeration_ext = false;
// Check for portability enumeration extension for MoltenVK support
for (const auto& properties : instance_extensions) {
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
return true;
}
}
if (!portability_enumeration_ext) {
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
}
#endif
return false;
UNUSED(instance_extensions);
}
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
#ifdef __APPLE__
bool portability_enumeration_ext = false;
// Check for portability enumeration extension for MoltenVK support
for (const auto& properties : instance_extensions) {
if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
return true;
}
}
if (!portability_enumeration_ext) {
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
}
#endif
return false;
UNUSED(instance_extensions);
}
// checks // checks
#ifdef GGML_VULKAN_CHECK_RESULTS #ifdef GGML_VULKAN_CHECK_RESULTS