whisper : fix multi-state Metal

This commit is contained in:
Georgi Gerganov 2023-11-12 14:24:45 +02:00
parent d029784fb0
commit 76c8b5235b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 46 additions and 68 deletions

1
.gitignore vendored
View File

@ -8,6 +8,7 @@
.DS_Store
build/
build-coreml/
build-em/
build-debug/
build-release/

View File

@ -26,7 +26,7 @@
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
#define GGML_METAL_MAX_BUFFERS 64
#define GGML_METAL_MAX_COMMAND_BUFFERS 32
struct ggml_tensor;

View File

@ -479,6 +479,10 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
const int64_t tsize = ggml_nbytes(t);
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
ctx = t->buffer->backend->context;
}
// find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

View File

@ -649,7 +649,6 @@ static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
auto & alloc = allocr.alloc;
auto & meta = allocr.meta;
auto & buffer = allocr.buffer;
alloc = ggml_allocr_new_measure_from_backend(backend);
@ -659,6 +658,11 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe
}
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
if (allocr.alloc == nullptr) {
// this can be null if we use external encoder like CoreML or OpenVINO
return;
}
auto & alloc = allocr.alloc;
auto & buffer = allocr.buffer;
@ -883,6 +887,37 @@ static void kv_cache_free(struct whisper_kv_cache & cache) {
}
}
static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
ggml_backend_t backend_gpu = NULL;
// initialize the backends
#ifdef GGML_USE_CUBLAS
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif
#ifdef GGML_USE_METAL
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
backend_gpu = ggml_backend_metal_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
}
}
#endif
if (backend_gpu) {
return backend_gpu;
}
return ggml_backend_cpu_init();
}
// load the model from a ggml file
//
// file format:
@ -1301,38 +1336,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
}
// init backend
{
ggml_backend_t backend_gpu = NULL;
// initialize the backends
#ifdef GGML_USE_CUBLAS
if (wctx.params.use_gpu) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif
#ifdef GGML_USE_METAL
if (wctx.params.use_gpu) {
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
backend_gpu = ggml_backend_metal_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
}
}
#endif
if (backend_gpu) {
wctx.backend = backend_gpu;
} else {
wctx.backend = ggml_backend_cpu_init();
}
}
wctx.backend = whisper_backend_init(wctx.params);
{
size_t size_main = 0;
@ -2827,38 +2831,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_state * state = new whisper_state;
// init backend
{
ggml_backend_t backend_gpu = NULL;
// initialize the backends
#ifdef GGML_USE_CUBLAS
if (ctx->params.use_gpu) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif
#ifdef GGML_USE_METAL
if (ctx->params.use_gpu) {
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
backend_gpu = ggml_backend_metal_init();
if (!backend_gpu) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
}
}
#endif
if (backend_gpu) {
state->backend = backend_gpu;
} else {
state->backend = ggml_backend_cpu_init();
}
}
state->backend = whisper_backend_init(ctx->params);
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
@ -2957,9 +2930,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
}
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
state->rng = std::mt19937(0);