From 018b2d340efd2bfc118446b47ecf30fdc912d8c0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 20 Jun 2025 11:19:15 +0300 Subject: [PATCH] ggml : fix repack work size for mul_mat_id (llama/14292) ggml-ci --- ggml/src/ggml-cpu/repack.cpp | 38 ++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5c6715d5..29071929 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1163,13 +1163,24 @@ template op) { case GGML_OP_MUL_MAT: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - return true; + { + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + return true; + } case GGML_OP_MUL_MAT_ID: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. - size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; - return true; + { + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens + + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + + size += sizeof_mmid_row_mapping*ne02*(ne12 + 1); + + return true; + } default: // GGML_ABORT("fatal error"); break; @@ -1305,14 +1316,17 @@ template wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + - n_as * ne12 * sizeof(mmid_row_mapping))); + GGML_ASSERT(params->wsize >= + (GGML_PAD(nbw3, sizeof(int64_t)) + + n_as*(ne12 + 1)*sizeof(mmid_row_mapping)) + ); - auto * wdata = (char *) params->wdata; - auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); - auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + auto * wdata = (char *)params->wdata; + auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t)); - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] + // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t) + auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] // src1: float32 => param type for (int64_t i12 = 0; i12 < ne12; ++i12) {