Mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : add abort callback (#1335)
commit 2f668c330e
parent 08fa34882f

whisper.cpp: 50 changed lines
whisper.cpp

@@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) {
 // ggml helpers
 //
 
-static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+static void ggml_graph_compute_helper(
+        std::vector<uint8_t> & buf,
+                 ggml_cgraph * graph,
+                         int   n_threads,
+      whisper_abort_callback   abort_callback,
+                        void * abort_callback_data) {
     struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
 
+    plan.abort_callback      = abort_callback;
+    plan.abort_callback_data = abort_callback_data;
+
     if (plan.work_size > 0) {
         buf.resize(plan.work_size);
         plan.work_data = buf.data();
@@ -1922,7 +1930,9 @@ static bool whisper_encode_internal(
         whisper_context & wctx,
           whisper_state & wstate,
               const int   mel_offset,
-              const int   n_threads) {
+              const int   n_threads,
+  whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     // conv
@@ -1936,7 +1946,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
     }
 
@@ -1955,10 +1965,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -1977,10 +1987,10 @@ static bool whisper_encode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -2346,7 +2356,9 @@ static bool whisper_decode_internal(
     const whisper_token * tokens,
               const int   n_tokens,
               const int   n_past,
-              const int   n_threads) {
+              const int   n_threads,
+  whisper_abort_callback   abort_callback,
+                   void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
 
     const auto & model = wctx.model;
@@ -2375,10 +2387,10 @@ static bool whisper_decode_internal(
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
         } else {
-            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
         }
 #else
-        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
+        ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data);
 #endif
     }
 
@@ -3290,7 +3302,7 @@ int whisper_set_mel(
 }
 
 int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) {
        log("%s: failed to eval\n", __func__);
        return -1;
    }
@@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state
 }
 
 int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
-    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) {
+    if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
        log("%s: failed to eval\n", __func__);
        return -1;
    }
@@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
        log("%s: failed to eval\n", __func__);
        return 1;
    }
@@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
        return false;
    }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) {
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
        log("%s: failed to eval\n", __func__);
        return 1;
    }
@@ -4594,7 +4606,7 @@ int whisper_full_with_state(
        }
 
        // encode audio features starting at offset seek
-        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) {
+        if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
            log("%s: failed to encode\n", __func__);
            return -6;
        }
@@ -4677,7 +4689,7 @@ int whisper_full_with_state(
            }
            WHISPER_PRINT_DEBUG("\n\n");
 
-            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
+            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                log("%s: failed to decode\n", __func__);
                return -7;
            }
@@ -4901,7 +4913,7 @@ int whisper_full_with_state(
 
                //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) {
+                if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                    log("%s: failed to decode\n", __func__);
                    return -8;
                }
@@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
         double tsum = 0.0;
 
         // heat-up
-        ggml_graph_compute_helper(work, &gf, n_threads);
+        ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr);
 
         for (int i = 0; i < n_max; ++i) {
             const int64_t t0 = ggml_time_us();
 
-            ggml_graph_compute_helper(work, &gf, n_threads);
+            ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr);
 
             const int64_t t1 = ggml_time_us();
 
whisper.h

@@ -334,6 +334,11 @@ extern "C" {
     // If it returns false, the computation is aborted
     typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data);
 
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*whisper_abort_callback)(void * user_data);
+
     // Logits filter callback
     // Can be used to modify the logits before sampling
     // If not NULL, called after applying temperature to logits
@@ -428,6 +433,10 @@ extern "C" {
         whisper_encoder_begin_callback encoder_begin_callback;
         void * encoder_begin_callback_user_data;
 
+        // called each time before ggml computation starts
+        whisper_abort_callback abort_callback;
+        void * abort_callback_user_data;
+
         // called by each decoder to filter obtained logits
         whisper_logits_filter_callback logits_filter_callback;
         void * logits_filter_callback_user_data;
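Usage note (not part of the commit): a minimal sketch of how a caller might hook the new abort callback through whisper_full_params. It assumes the pre-existing whisper.h C API (whisper_init_from_file, whisper_full_default_params, whisper_full, whisper_free); the model path and the g_abort_requested flag are illustrative only. Returning true from the callback aborts the ggml graph computation.

// sketch: wiring the new abort_callback into whisper_full_params
// assumes the pre-existing whisper.h API; model path is illustrative
#include "whisper.h"

#include <atomic>
#include <vector>

// set from another thread or a signal handler to request cancellation
static std::atomic<bool> g_abort_requested{false};

int main() {
    struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin");
    if (ctx == nullptr) {
        return 1;
    }

    whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // new in this commit: called before each ggml graph computation;
    // returning true aborts the computation
    params.abort_callback = [](void * user_data) {
        return static_cast<std::atomic<bool> *>(user_data)->load();
    };
    params.abort_callback_user_data = &g_abort_requested;

    std::vector<float> pcmf32; // 16 kHz mono float samples, filled elsewhere

    if (whisper_full(ctx, params, pcmf32.data(), (int) pcmf32.size()) != 0) {
        // transcription failed or was aborted via the callback
    }

    whisper_free(ctx);
    return 0;
}

As the diff shows, only whisper_full_with_state forwards params.abort_callback into the encode/decode internals; whisper_encode, whisper_decode, and the mul_mat benchmark pass nullptr, so existing callers are unaffected.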