whisper : fix tensor allocation during load

This commit is contained in:
Georgi Gerganov 2023-11-10 11:51:55 +02:00
parent 7e01486b61
commit 3dfbe64911
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1422,6 +1422,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv); ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv);
ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main); ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main);
// allocate tensors in the backend buffers
{
for (const auto & t : model.tensors) {
if (t.first.find("conv") != std::string::npos) {
ggml_allocr_alloc(alloc_conv, t.second);
} else {
ggml_allocr_alloc(alloc_main, t.second);
}
}
}
// load weights // load weights
{ {
size_t total_size = 0; size_t total_size = 0;
@ -1484,10 +1495,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
const bool is_conv = name.find("conv") != std::string::npos; const bool is_conv = name.find("conv") != std::string::npos;
ggml_allocr * alloc = is_conv ? alloc_conv : alloc_main;
ggml_backend * backend = is_conv ? wctx.backend_conv() : wctx.backend_main(); ggml_backend * backend = is_conv ? wctx.backend_conv() : wctx.backend_main();
ggml_allocr_alloc(alloc, tensor);
//printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str()); //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
if (ggml_backend_is_cpu(backend) if (ggml_backend_is_cpu(backend)