mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-22 08:30:07 +00:00
ggml : fix repack work size for mul_mat_id (llama/14292)
ggml-ci
This commit is contained in:
@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||||||
// not realy a GGML_TYPE_Q8_0 but same size.
|
// not realy a GGML_TYPE_Q8_0 but same size.
|
||||||
switch (op->op) {
|
switch (op->op) {
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
|
{
|
||||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
{
|
||||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||||
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
||||||
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
|
||||||
|
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
|
||||||
|
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
|
||||||
|
|
||||||
|
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
|
||||||
|
|
||||||
|
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
// GGML_ABORT("fatal error");
|
// GGML_ABORT("fatal error");
|
||||||
break;
|
break;
|
||||||
@ -1305,13 +1316,16 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||||||
int32_t i2;
|
int32_t i2;
|
||||||
};
|
};
|
||||||
|
|
||||||
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
GGML_ASSERT(params->wsize >=
|
||||||
n_as * ne12 * sizeof(mmid_row_mapping)));
|
(GGML_PAD(nbw3, sizeof(int64_t)) +
|
||||||
|
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
|
||||||
|
);
|
||||||
|
|
||||||
auto * wdata = (char *) params->wdata;
|
auto * wdata = (char *)params->wdata;
|
||||||
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
||||||
|
|
||||||
|
// total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
|
||||||
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||||
|
|
||||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
||||||
|
|
||||||
// src1: float32 => param type
|
// src1: float32 => param type
|
||||||
|
Reference in New Issue
Block a user