#!/usr/bin/env python3
"""Regenerate the CUDA template-instance ``.cu`` files.

When run as a script, every ``*.cu`` file in the output directory (the
current working directory by default) is deleted, then one file is written
per kernel instance:

* ``fattn-vec-f{16,32}-instance-hs*-<k>-<v>.cu`` (FlashAttention vec kernels),
* ``fattn-mma-f16-instance-ncols1_*-ncols2_*.cu`` (FlashAttention MMA kernels),
* ``mmq-instance-<type>.cu`` (MMQ kernels).
"""

from glob import escape, glob
import os

TYPES_KV = [
    "GGML_TYPE_Q4_0",
    "GGML_TYPE_Q4_1",
    "GGML_TYPE_Q5_0",
    "GGML_TYPE_Q5_1",
    "GGML_TYPE_Q8_0",
    "GGML_TYPE_F16",
]

SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-vec-f{vkq_size}.cuh"

DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
"""

SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-mma-f16.cuh"

"""

SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"

TYPES_MMQ = [
    "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
    "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS",
]

SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../mmq.cuh"

DECL_MMQ_CASE({type});
"""


def get_short_name(long_quant_name):
    """Return the lower-case short name of a ggml type, e.g. ``GGML_TYPE_Q4_0`` -> ``q4_0``."""
    return long_quant_name.replace("GGML_TYPE_", "").lower()


def get_head_sizes(type_k, type_v):
    """Return the head sizes to instantiate for a vec-kernel K/V type pair.

    F16 K with F16 V gets [64, 128, 256]; F16 K with any other V type gets
    [64, 128]; every other K type gets only [128].
    """
    if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
        return [64, 128, 256]
    if type_k == "GGML_TYPE_F16":
        return [64, 128]
    return [128]


def _write_file(out_dir, filename, text):
    """Write *text* to *filename* inside *out_dir* as UTF-8 with LF line endings."""
    # newline="\n" keeps the generated sources byte-identical on every platform
    # (plain text mode would translate "\n" to CRLF on Windows).
    with open(os.path.join(out_dir, filename), "w", encoding="utf-8", newline="\n") as f:
        f.write(text)


def _remove_stale_instances(out_dir):
    """Delete every ``*.cu`` file directly inside *out_dir*."""
    # escape() makes glob metacharacters in the directory path match literally.
    for path in glob(os.path.join(escape(out_dir), "*.cu")):
        os.remove(path)


def _generate_fattn_vec(out_dir):
    """Write one file per (VKQ precision, head size, K type, V type) vec-kernel instance."""
    for vkq_size in [16, 32]:
        for type_k in TYPES_KV:
            for type_v in TYPES_KV:
                for head_size in get_head_sizes(type_k, type_v):
                    filename = (
                        f"fattn-vec-f{vkq_size}-instance-hs{head_size}"
                        f"-{get_short_name(type_k)}-{get_short_name(type_v)}.cu"
                    )
                    text = SOURCE_FATTN_VEC.format(
                        vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v
                    )
                    _write_file(out_dir, filename, text)


def _generate_fattn_mma(out_dir):
    """Write one file per (ncols1, ncols2) MMA split, declaring every head size it supports."""
    for ncols in [8, 16, 32, 64]:
        for ncols2 in [1, 2, 4, 8, 16]:
            if ncols2 > ncols:
                continue
            ncols1 = ncols // ncols2
            parts = [SOURCE_FATTN_MMA_START]
            for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
                # KQ head size 576 is instantiated only with ncols2 == 16, and
                # ncols2 == 16 only with KQ head size 576.
                if (head_size_kq == 576) != (ncols2 == 16):
                    continue
                # 576 is the only KQ head size whose V head size differs (512).
                head_size_v = head_size_kq if head_size_kq != 576 else 512
                parts.append(SOURCE_FATTN_MMA_CASE.format(
                    ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v
                ))
            filename = f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu"
            _write_file(out_dir, filename, "".join(parts))


def _generate_mmq(out_dir):
    """Write one file per type listed in TYPES_MMQ."""
    for quant_type in TYPES_MMQ:
        _write_file(out_dir, f"mmq-instance-{get_short_name(quant_type)}.cu", SOURCE_MMQ.format(type=quant_type))


def main(out_dir="."):
    """Regenerate all template-instance files in *out_dir* (default: the current directory).

    Every existing ``*.cu`` file in *out_dir* is deleted before the new files
    are written.
    """
    _remove_stale_instances(out_dir)
    _generate_fattn_vec(out_dir)
    _generate_fattn_mma(out_dir)
    _generate_mmq(out_dir)


# Guarded so that importing this module does not delete or write any files.
if __name__ == "__main__":
    main()