diff --git a/ggml.c b/ggml.c index d2a8c047..f5caeba0 100644 --- a/ggml.c +++ b/ggml.c @@ -394,6 +394,12 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); +ggml_collect_imatrix_t g_imatrix_collect = NULL; + +void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect) { + g_imatrix_collect = imatrix_collect; +} + static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { .type_name = "i8", @@ -9763,6 +9769,10 @@ static void ggml_compute_forward_mul_mat( const int ith = params->ith; const int nth = params->nth; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0, src1); + } + const enum ggml_type type = src0->type; const bool src1_cont = ggml_is_contiguous(src1); @@ -10066,6 +10076,10 @@ static void ggml_compute_forward_mul_mat_id( const struct ggml_tensor * src0_cur = dst->src[cur_a + 2]; + if (ith == 1 && g_imatrix_collect) { + g_imatrix_collect(src0_cur, src1); + } + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); diff --git a/ggml.h b/ggml.h index 93b42a27..4c2ff6c6 100644 --- a/ggml.h +++ b/ggml.h @@ -2067,6 +2067,12 @@ extern "C" { GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + // + // Importance matrix + // + typedef void(*ggml_collect_imatrix_t)(const struct ggml_tensor * src0, const struct ggml_tensor * src1); + GGML_API void ggml_set_imatrix_collection(ggml_collect_imatrix_t imatrix_collect); + // // gguf //