ggml : fix ggml_nbytes (probably temp solution)

This commit is contained in:
Georgi Gerganov 2023-09-12 20:10:53 +03:00
parent 79a88057bd
commit 9fdd415367
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

39
ggml.c
View File

@ -4303,10 +4303,43 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
}
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
// original:
//size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
//for (int i = 1; i < GGML_MAX_DIMS; ++i) {
// nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
//}
//return nbytes;
// TODO: the imlpementation below is stupid - need something better
// sort ne and nb
int64_t sne[GGML_MAX_DIMS];
size_t snb[GGML_MAX_DIMS];
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
sne[i] = tensor->ne[i];
snb[i] = tensor->nb[i];
}
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
for (int j = i + 1; j < GGML_MAX_DIMS; ++j) {
if ((snb[i] > snb[j]) || (snb[i] == snb[j] && sne[i] < sne[j])) {
size_t tmp = snb[i];
snb[i] = snb[j];
snb[j] = tmp;
int64_t tmp2 = sne[i];
sne[i] = sne[j];
sne[j] = tmp2;
}
}
}
size_t nbytes = (sne[0]/ggml_blck_size(tensor->type))*snb[0];
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (sne[i] - 1)*snb[i];
}
return nbytes;
}