mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-09 04:15:15 +00:00
metal : refactor mat-vec code (llama/12569)
* metal : refactor mat-vec code ggml-ci
* metal : rename all_sum -> sum_all ggml-ci
* metal : fix comments [no ci]
* metal : fix nr constant [no ci]
* metal : mv q6_K support nr0 > 1 ggml-ci
* metal : reduce register pressure ggml-ci
* metal : fix typo [no ci]
* metal : reduce register pressure ggml-ci
This commit is contained in:
parent
3c4d363872
commit
f4f619ea8e
@ -1,6 +1,70 @@
|
||||
#ifndef GGML_METAL_IMPL
|
||||
#define GGML_METAL_IMPL
|
||||
|
||||
// kernel parameters for mat-vec threadgroups
|
||||
//
|
||||
// N_R0: number of src0 rows to process per simdgroup
|
||||
// N_SG: number of simdgroups per threadgroup
|
||||
//
|
||||
// TODO: for optimal performance, these should become functions of the device and work size
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
#define N_R0_Q4_1 4
|
||||
#define N_SG_Q4_1 2
|
||||
|
||||
#define N_R0_Q5_0 4
|
||||
#define N_SG_Q5_0 2
|
||||
|
||||
#define N_R0_Q5_1 4
|
||||
#define N_SG_Q5_1 2
|
||||
|
||||
#define N_R0_Q8_0 4
|
||||
#define N_SG_Q8_0 2
|
||||
|
||||
#define N_R0_Q2_K 4
|
||||
#define N_SG_Q2_K 2
|
||||
|
||||
#define N_R0_Q3_K 2
|
||||
#define N_SG_Q3_K 2
|
||||
|
||||
#define N_R0_Q4_K 4
|
||||
#define N_SG_Q4_K 2
|
||||
|
||||
#define N_R0_Q5_K 2
|
||||
#define N_SG_Q5_K 2
|
||||
|
||||
#define N_R0_Q6_K 1
|
||||
#define N_SG_Q6_K 2
|
||||
|
||||
#define N_R0_IQ1_S 4
|
||||
#define N_SG_IQ1_S 2
|
||||
|
||||
#define N_R0_IQ1_M 4
|
||||
#define N_SG_IQ1_M 2
|
||||
|
||||
#define N_R0_IQ2_XXS 4
|
||||
#define N_SG_IQ2_XXS 2
|
||||
|
||||
#define N_R0_IQ2_XS 4
|
||||
#define N_SG_IQ2_XS 2
|
||||
|
||||
#define N_R0_IQ2_S 4
|
||||
#define N_SG_IQ2_S 2
|
||||
|
||||
#define N_R0_IQ3_XXS 4
|
||||
#define N_SG_IQ3_XXS 2
|
||||
|
||||
#define N_R0_IQ3_S 4
|
||||
#define N_SG_IQ3_S 2
|
||||
|
||||
#define N_R0_IQ4_NL 2
|
||||
#define N_SG_IQ4_NL 2
|
||||
|
||||
#define N_R0_IQ4_XS 2
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
|
@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node(
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
int nrows = 1;
|
||||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
int nsg = 0; // number of simdgroups
|
||||
int nr0 = 0; // number of src0 rows per simdgroup
|
||||
int nr1 = 1; // number of src1 rows per threadgroup
|
||||
|
||||
size_t smem = 0; // shared memory
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
nr1 = 4;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
|
||||
nrows = 4;
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
if (src1t == GGML_TYPE_F32) {
|
||||
if (ne11 * ne12 < 4) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
|
||||
nrows = ne11;
|
||||
nr1 = ne11;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
|
||||
nrows = 4;
|
||||
nr1 = 4;
|
||||
}
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
|
||||
nrows = 4;
|
||||
nr1 = 4;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
if (src1t == GGML_TYPE_F32) {
|
||||
if (ne11 * ne12 < 4) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
|
||||
nrows = ne11;
|
||||
nr1 = ne11;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
|
||||
nrows = 4;
|
||||
nr1 = 4;
|
||||
}
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
|
||||
nrows = 4;
|
||||
nr1 = 4;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q4_0;
|
||||
nr0 = N_R0_Q4_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q4_1;
|
||||
nr0 = N_R0_Q4_1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q5_0;
|
||||
nr0 = N_R0_Q5_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q5_1;
|
||||
nr0 = N_R0_Q5_1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q8_0;
|
||||
nr0 = N_R0_Q8_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q2_K;
|
||||
nr0 = N_R0_Q2_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q3_K;
|
||||
nr0 = N_R0_Q3_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
{
|
||||
nth0 = 4; //1;
|
||||
nth1 = 8; //32;
|
||||
nsg = N_SG_Q4_K;
|
||||
nr0 = N_R0_Q4_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q5_K;
|
||||
nr0 = N_R0_Q5_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q6_K;
|
||||
nr0 = N_R0_Q6_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_XXS;
|
||||
nr0 = N_R0_IQ2_XXS;
|
||||
smem = 256*8+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_XS;
|
||||
nr0 = N_R0_IQ2_XS;
|
||||
smem = 512*8+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ3_XXS;
|
||||
nr0 = N_R0_IQ3_XXS;
|
||||
smem = 256*4+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ3_S;
|
||||
nr0 = N_R0_IQ3_S;
|
||||
smem = 512*4;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_S;
|
||||
nr0 = N_R0_IQ2_S;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ1_S;
|
||||
nr0 = N_R0_IQ1_S;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ1_M;
|
||||
nr0 = N_R0_IQ1_M;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ4_NL;
|
||||
nr0 = N_R0_IQ4_NL;
|
||||
smem = 32*sizeof(float);
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ4_XS;
|
||||
nr0 = N_R0_IQ4_XS;
|
||||
smem = 32*sizeof(float);
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
||||
} break;
|
||||
default:
|
||||
@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
if (smem > 0) {
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
}
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
int nrows = 1;
|
||||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
int nsg = 0; // number of simdgroups
|
||||
int nr0 = 0; // number of src0 rows per simdgroup
|
||||
int nr1 = 1; // number of src1 rows per threadgroup
|
||||
|
||||
size_t smem = 0; // shared memory
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q4_0;
|
||||
nr0 = N_R0_Q4_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q4_1;
|
||||
nr0 = N_R0_Q4_1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q5_0;
|
||||
nr0 = N_R0_Q5_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q5_1;
|
||||
nr0 = N_R0_Q5_1;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
nsg = N_SG_Q8_0;
|
||||
nr0 = N_R0_Q8_0;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q2_K;
|
||||
nr0 = N_R0_Q2_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q3_K;
|
||||
nr0 = N_R0_Q3_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
{
|
||||
nth0 = 4; //1;
|
||||
nth1 = 8; //32;
|
||||
nsg = N_SG_Q4_K;
|
||||
nr0 = N_R0_Q4_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q5_K;
|
||||
nr0 = N_R0_Q5_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nsg = N_SG_Q6_K;
|
||||
nr0 = N_R0_Q6_K;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_XXS;
|
||||
nr0 = N_R0_IQ2_XXS;
|
||||
smem = 256*8+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_XS;
|
||||
nr0 = N_R0_IQ2_XS;
|
||||
smem = 512*8+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ3_XXS;
|
||||
nr0 = N_R0_IQ3_XXS;
|
||||
smem = 256*4+128;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ3_S;
|
||||
nr0 = N_R0_IQ3_S;
|
||||
smem = 512*4;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ2_S;
|
||||
nr0 = N_R0_IQ2_S;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ1_S;
|
||||
nr0 = N_R0_IQ1_S;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ1_M;
|
||||
nr0 = N_R0_IQ1_M;
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ4_NL;
|
||||
nr0 = N_R0_IQ4_NL;
|
||||
smem = 32*sizeof(float);
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nsg = N_SG_IQ4_XS;
|
||||
nr0 = N_R0_IQ4_XS;
|
||||
smem = 32*sizeof(float);
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
||||
} break;
|
||||
default:
|
||||
@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
if (ggml_is_quantized(src0t)) {
|
||||
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||
GGML_ASSERT(ne00 >= nsg*nr0);
|
||||
}
|
||||
|
||||
ggml_metal_kargs_mul_mv_id args = {
|
||||
@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||
|
||||
const int64_t _ne1 = 1;
|
||||
const int tgz = dst_rows;
|
||||
const int64_t ne123 = dst_rows;
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
if (smem > 0) {
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
}
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user