mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-03-10 14:34:01 +00:00
command : grammar-related improvements
- option to read grammar from file - add sample grammars for colors and chess moves - fine-tune the performance further
This commit is contained in:
parent
b8f34d1ed7
commit
54d168db67
@ -22,6 +22,11 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
bool file_exists(const std::string & fname) {
|
||||
std::ifstream f(fname.c_str());
|
||||
return f.good();
|
||||
}
|
||||
|
||||
// command-line parameters
|
||||
struct whisper_params {
|
||||
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
@ -36,6 +41,8 @@ struct whisper_params {
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
grammar_parser::parse_state grammar_parsed;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool print_special = false;
|
||||
@ -117,15 +124,18 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
|
||||
std::string transcribe(
|
||||
whisper_context * ctx,
|
||||
const whisper_params & params,
|
||||
const std::vector<float> & pcmf32,
|
||||
const std::string & grammar_rule,
|
||||
float & prob,
|
||||
int64_t & t_ms) {
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
prob = 0.0f;
|
||||
t_ms = 0;
|
||||
|
||||
grammar_parser::parse_state parsed_grammar;
|
||||
std::vector<const whisper_grammar_element *> grammar_rules;
|
||||
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_progress = false;
|
||||
@ -140,17 +150,20 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
// disable fallback - seems not useful for command recognition
|
||||
wparams.temperature_inc = 0.0f;
|
||||
wparams.temperature_inc = 0.00f;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
grammar_rules = parsed_grammar.c_rules();
|
||||
//wparams.initial_prompt = params.prompt.data();
|
||||
|
||||
const auto & grammar_parsed = params.grammar_parsed;
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (!params.grammar_parsed.rules.empty()) {
|
||||
wparams.grammar_rules = grammar_rules.data();
|
||||
wparams.n_grammar_rules = grammar_rules.size();
|
||||
wparams.i_start_rule = parsed_grammar.symbol_ids.at("root");
|
||||
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
|
||||
wparams.grammar_penalty = params.grammar_penalty;
|
||||
}
|
||||
|
||||
@ -270,7 +283,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
|
||||
fprintf(stderr, " ]\n");
|
||||
}
|
||||
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
std::string k_prompt = "select one from the available words: ";
|
||||
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
||||
if (i > 0) {
|
||||
k_prompt += ", ";
|
||||
@ -476,7 +489,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
// detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", prob, t_ms));
|
||||
|
||||
const auto words = get_words(txt);
|
||||
|
||||
@ -523,9 +536,10 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
//const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
//const std::string k_prompt = "Начало.";
|
||||
const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди.";
|
||||
std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
if (!params.prompt.empty()) {
|
||||
k_prompt = params.prompt;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@ -558,7 +572,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// wait for activation phrase
|
||||
audio.get(params.prompt_ms, pcmf32_cur);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob0, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
@ -581,13 +595,16 @@ int process_general_transcription(struct whisper_context * ctx, audio_async & au
|
||||
// we have heard the activation phrase, now detect the commands
|
||||
audio.get(params.command_ms, pcmf32_cur);
|
||||
|
||||
//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);
|
||||
|
||||
// prepend 3 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
// append 1 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
|
||||
@ -688,13 +705,23 @@ int main(int argc, char ** argv) {
|
||||
int ret_val = 0;
|
||||
|
||||
if (!params.grammar.empty()) {
|
||||
auto parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
auto & grammar = params.grammar_parsed;
|
||||
if (file_exists(params.grammar.c_str())) {
|
||||
// read grammar from file
|
||||
std::ifstream ifs(params.grammar.c_str());
|
||||
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
|
||||
grammar = grammar_parser::parse(txt.c_str());
|
||||
} else {
|
||||
// read grammar from string
|
||||
grammar = grammar_parser::parse(params.grammar.c_str());
|
||||
}
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
if (grammar.rules.empty()) {
|
||||
ret_val = 1;
|
||||
} else {
|
||||
fprintf(stderr, "%s: grammar:\n", __func__);
|
||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
||||
grammar_parser::print_grammar(stderr, grammar);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
@ -702,7 +729,7 @@ int main(int argc, char ** argv) {
|
||||
if (ret_val == 0) {
|
||||
if (!params.commands.empty()) {
|
||||
ret_val = process_command_list(ctx, audio, params);
|
||||
} else if (!params.prompt.empty()) {
|
||||
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
|
||||
ret_val = always_prompt_transcription(ctx, audio, params);
|
||||
} else {
|
||||
ret_val = process_general_transcription(ctx, audio, params);
|
||||
|
@ -413,7 +413,7 @@ namespace grammar_parser {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() {
|
||||
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
|
||||
std::vector<const whisper_grammar_element *> ret;
|
||||
for (const auto & rule : rules) {
|
||||
ret.push_back(rule.data());
|
||||
|
@ -21,7 +21,7 @@ namespace grammar_parser {
|
||||
std::map<std::string, uint32_t> symbol_ids;
|
||||
std::vector<std::vector<whisper_grammar_element>> rules;
|
||||
|
||||
std::vector<const whisper_grammar_element *> c_rules();
|
||||
std::vector<const whisper_grammar_element *> c_rules() const;
|
||||
};
|
||||
|
||||
parse_state parse(const char * src);
|
||||
|
27
grammars/chess.gbnf
Normal file
27
grammars/chess.gbnf
Normal file
@ -0,0 +1,27 @@
|
||||
# - bishop to c3
|
||||
# - rook to d4
|
||||
# - knight to e5
|
||||
# - d4 d5 knight to c3
|
||||
# - c3 queen to d4 king b1
|
||||
# - pawn to a1 bishop to b2 knight to c3
|
||||
#
|
||||
# initial prompt:
|
||||
#
|
||||
# "pawn to a1, bishop to b2, knight to c3, rook to d4, queen to e5, king to f6,"
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6"
|
||||
#
|
||||
|
||||
root ::= init (move? move? move? ".")
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " pawn knight king a1 f5 h6"
|
||||
|
||||
move ::= " " ((piece | pawn | king) " " "to "?)? [a-h] [1-8]
|
||||
|
||||
piece ::= "bishop" | "rook" | "knight" | "queen"
|
||||
king ::= "king"
|
||||
pawn ::= "pawn"
|
24
grammars/colors.gbnf
Normal file
24
grammars/colors.gbnf
Normal file
@ -0,0 +1,24 @@
|
||||
# - red
|
||||
# - green
|
||||
# - blue
|
||||
# - red green
|
||||
# - red blue
|
||||
# - green red
|
||||
# - green blue green
|
||||
#
|
||||
# initial prompt:
|
||||
#
|
||||
# "red green blue"
|
||||
#
|
||||
# example:
|
||||
#
|
||||
# ./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue"
|
||||
#
|
||||
|
||||
root ::= init color (color)? (color)? "."
|
||||
prompt ::= init "."
|
||||
|
||||
# leading space is very important!
|
||||
init ::= " red green blue"
|
||||
|
||||
color ::= " " ("red" | "green" | "blue")
|
65
whisper.cpp
65
whisper.cpp
@ -3865,28 +3865,29 @@ static struct whisper_grammar whisper_grammar_init(
|
||||
static void whisper_suppress_invalid_grammar(
|
||||
whisper_context & ctx,
|
||||
const whisper_full_params & params,
|
||||
std::vector<float> & logprobs,
|
||||
std::vector<float> & logits,
|
||||
const whisper_grammar & grammar) {
|
||||
|
||||
if (grammar.rules.empty() || grammar.stacks.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// bool allow_eot = false;
|
||||
// for (const auto & stack : grammar.stacks) {
|
||||
// if (stack.empty()) {
|
||||
// allow_eot = true;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
bool allow_eot = false;
|
||||
for (const auto & stack : grammar.stacks) {
|
||||
if (stack.empty()) {
|
||||
allow_eot = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const whisper_token eot = whisper_token_eot(&ctx);
|
||||
|
||||
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
|
||||
std::vector<whisper_grammar_candidate> candidates_grammar;
|
||||
|
||||
size_t size = logprobs.size();
|
||||
for (whisper_token id = 0; id < (int) size; ++id) {
|
||||
for (whisper_token id = 0; id < eot; ++id) {
|
||||
const std::string & text = ctx.vocab.id_to_token[id];
|
||||
if (!text.empty() && text.rfind("[_", 0) != 0) {
|
||||
if (!text.empty()) {
|
||||
candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
|
||||
candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||
}
|
||||
@ -3895,14 +3896,12 @@ static void whisper_suppress_invalid_grammar(
|
||||
const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
|
||||
|
||||
for (const auto & reject : rejects) {
|
||||
logprobs[reject.id] -= params.grammar_penalty;
|
||||
logits[reject.id] -= params.grammar_penalty;
|
||||
}
|
||||
|
||||
// when the grammar does not allow any continuation, we don't want to penalize the EOT token
|
||||
// TODO: is there are better way to do this?
|
||||
printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
|
||||
if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
|
||||
logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
|
||||
// when the grammar allows a continuation, we penalize the end-of-text token
|
||||
if (!allow_eot) {
|
||||
logits[eot] -= params.grammar_penalty;
|
||||
}
|
||||
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
||||
}
|
||||
@ -3912,7 +3911,7 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
|
||||
//fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
|
||||
|
||||
const std::string & text = ctx.vocab.id_to_token[token];
|
||||
|
||||
@ -4308,14 +4307,28 @@ static void whisper_process_logits(
|
||||
logits[i] = -INFINITY;
|
||||
logprobs[i] = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
//printf("sampling text\n");
|
||||
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
||||
logits[i] = -INFINITY;
|
||||
logprobs[i] = -INFINITY;
|
||||
}
|
||||
} else if (params.n_grammar_rules > 0) {
|
||||
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
||||
|
||||
whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
|
||||
// populate the logprobs array (log_softmax)
|
||||
{
|
||||
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
||||
float logsumexp = 0.0f;
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
if (logits[i] > -INFINITY) {
|
||||
logsumexp += expf(logits[i] - logit_max);
|
||||
}
|
||||
}
|
||||
logsumexp = logf(logsumexp) + logit_max;
|
||||
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
if (logits[i] > -INFINITY) {
|
||||
logprobs[i] = logits[i] - logsumexp;
|
||||
} else {
|
||||
logprobs[i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4331,7 +4344,7 @@ static void whisper_process_logits(
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// print first 100 logits - token string : logit
|
||||
//for (int i = 0; i < 10; i++) {
|
||||
// const auto token = vocab.id_to_token.at(i);
|
||||
|
Loading…
x
Reference in New Issue
Block a user