// mirror of https://github.com/ggerganov/whisper.cpp.git
// synced 2024-12-20 13:13:07 +00:00
// 145 lines, 4.5 KiB, C++
#pragma once
|
|
|
|
#include "llama-impl.h"
|
|
|
|
#include <map>
|
|
|
|
struct llama_vocab;
|
|
|
|
// grammar element type
//
// Tags the `type` field of a llama_grammar_element. Every enumerator is given
// an explicit numeric value.
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b], [^abc])
    LLAMA_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY       = 7,
};
|
|
|
|
typedef struct llama_grammar_element {
|
|
enum llama_gretype type;
|
|
uint32_t value; // Unicode code point or rule ID
|
|
} llama_grammar_element;
|
|
|
|
// state of a UTF-8 sequence that has only been partially decoded so far
// (plain aggregate: initialise as { value, n_remain })
struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};
|
|
|
|
struct llama_grammar_candidate {
|
|
size_t index;
|
|
const uint32_t * code_points;
|
|
llama_partial_utf8 partial_utf8;
|
|
};
|
|
|
|
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
|
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
|
|
|
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
|
|
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
|
|
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
|
// produces the N possible stacks if the given char is accepted at those
|
|
// positions
|
|
void llama_grammar_accept(
|
|
const llama_grammar_rules & rules,
|
|
const llama_grammar_stacks & stacks,
|
|
uint32_t chr,
|
|
llama_grammar_stacks & stacks_new);
|
|
|
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
|
const llama_grammar_rules & rules,
|
|
const llama_grammar_stack & stack,
|
|
const llama_grammar_candidates & candidates);
|
|
|
|
struct llama_grammar_parser {
|
|
std::map<std::string, uint32_t> symbol_ids;
|
|
|
|
llama_grammar_rules rules;
|
|
|
|
llama_grammar_stack c_rules() const;
|
|
|
|
uint32_t get_symbol_id(const char * src, size_t len);
|
|
uint32_t generate_symbol_id(const std::string & base_name);
|
|
|
|
void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
|
|
|
|
const char * parse_alternates(
|
|
const char * src,
|
|
const std::string & rule_name,
|
|
uint32_t rule_id,
|
|
bool is_nested);
|
|
|
|
const char * parse_sequence(
|
|
const char * src,
|
|
const std::string & rule_name,
|
|
llama_grammar_rule & rule,
|
|
bool is_nested);
|
|
|
|
const char * parse_rule(const char * src);
|
|
|
|
bool parse(const char * src);
|
|
void print(FILE * file);
|
|
};
|
|
|
|
struct llama_grammar {
|
|
// note: allow null vocab for testing (not great)
|
|
const llama_vocab * vocab;
|
|
|
|
const llama_grammar_rules rules; // TODO: shared ptr
|
|
llama_grammar_stacks stacks;
|
|
|
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
|
llama_partial_utf8 partial_utf8;
|
|
};
|
|
|
|
//
// internal API
//
|
|
|
|
// note: needed for tests (not great)
|
|
struct llama_grammar * llama_grammar_init_impl(
|
|
const struct llama_vocab * vocab,
|
|
const llama_grammar_element ** rules,
|
|
size_t n_rules,
|
|
size_t start_rule_index);
|
|
|
|
// overload taking the grammar as text
//   vocab        - vocabulary pointer
//   grammar_str  - grammar source text
//   grammar_root - presumably the name of the root rule — TODO confirm
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
                      const char * grammar_str,
                      const char * grammar_root);
|
|
|
|
// NOTE(review): presumably releases a grammar obtained from
// llama_grammar_init_impl / llama_grammar_clone_impl — confirm ownership rules
void llama_grammar_free_impl(struct llama_grammar * grammar);

// returns a pointer to a grammar created from `grammar`
// (presumably an independent copy — TODO confirm)
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
|
|
|
|
// TODO: move the API below as member functions of llama_grammar
|
|
void llama_grammar_apply_impl(
|
|
const struct llama_grammar & grammar,
|
|
llama_token_data_array * cur_p);
|
|
|
|
void llama_grammar_accept_impl(
|
|
struct llama_grammar & grammar,
|
|
llama_token token);
|