Mirror of https://github.com/ggerganov/whisper.cpp.git, synced 2024-12-24 22:56:42 +00:00

talk-llama : sync llama.cpp

commit fe18c29ab8 (parent 234f9bd320)

Makefile (10 lines changed):
@@ -1080,10 +1080,12 @@ lsp: examples/lsp/lsp.cpp \
 	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
-talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
-	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
-	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
+# TODO: disabled until update
+# https://github.com/ggerganov/whisper.cpp/issues/1818
+#talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp \
+#	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
+#	$(CXX) $(CXXFLAGS) $(CFLAGS_SDL) -c $< -o $(call GET_OBJ_FILE, $<)
+#	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LDFLAGS_SDL)
 
 talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp examples/talk-llama/llama-vocab.cpp examples/talk-llama/llama-grammar.cpp examples/talk-llama/llama-sampling.cpp examples/talk-llama/unicode.cpp examples/talk-llama/unicode-data.cpp \
 	$(OBJ_GGML) $(OBJ_WHISPER) $(OBJ_COMMON) $(OBJ_SDL)
@@ -127,8 +127,10 @@ endif (WHISPER_SDL2)
 add_subdirectory(quantize)
 set_target_properties(quantize PROPERTIES FOLDER "examples")
 if (WHISPER_SDL2)
-    add_subdirectory(talk)
-    set_target_properties(talk PROPERTIES FOLDER "examples")
+    # TODO: disabled until update
+    # https://github.com/ggerganov/whisper.cpp/issues/1818
+    #add_subdirectory(talk)
+    #set_target_properties(talk PROPERTIES FOLDER "examples")
     add_subdirectory(talk-llama)
     set_target_properties(talk-llama PROPERTIES FOLDER "examples")
     add_subdirectory(lsp)
@@ -3,11 +3,31 @@
 #include "llama-vocab.h"
 #include "llama-sampling.h"
 
+#include <cmath>
 #include <algorithm>
+#include <stdexcept>
 
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+//
+// helpers
+//
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t  first_byte = static_cast<uint8_t>(*src);
+    uint8_t  highbits = first_byte >> 4;
+    int      len      = lookup[highbits];
+    uint8_t  mask     = (1 << (8 - len)) - 1;
+    uint32_t value    = first_byte & mask;
+    const char * end  = src + len; // may overrun!
+    const char * pos  = src + 1;
+    for ( ; pos < end && *pos; pos++) {
+        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+    }
+    return std::make_pair(value, pos);
+}
+
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
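Aside (not part of the commit): the single-argument decode_utf8 helper added above decodes one UTF-8 code point by using the high nibble of the first byte to look up the sequence length, then folding 6 bits from each continuation byte into the result. A minimal standalone sketch of the same idea — the names decode_one and the main() driver are invented for illustration:

#include <cstdint>
#include <cstdio>
#include <utility>

// Illustrative restatement of the table-driven decode: the top 4 bits of the
// first byte select the sequence length, the remaining bits seed the code
// point, and each continuation byte contributes its low 6 bits.
static std::pair<uint32_t, const char *> decode_one(const char * src) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t  first = static_cast<uint8_t>(*src);
    int      len   = lookup[first >> 4];
    uint32_t value = first & ((1 << (8 - len)) - 1);
    const char * pos = src + 1;
    for (const char * end = src + len; pos < end && *pos; pos++) {
        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
    }
    return { value, pos };
}

int main() {
    const char * s = "é";                       // U+00E9, two bytes in UTF-8
    auto decoded = decode_one(s);
    // prints: U+00E9 (2 byte(s))
    printf("U+%04X (%d byte(s))\n", (unsigned) decoded.first, (int) (decoded.second - s));
    return 0;
}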
@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
|||||||
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
|
static bool is_digit_char(char c) {
|
||||||
return grammar->rules;
|
return '0' <= c && c <= '9';
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
|
static bool is_word_char(char c) {
|
||||||
return grammar->stacks;
|
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||||
|
const char * pos = src;
|
||||||
|
const char * end = src + size;
|
||||||
|
uint32_t value = 0;
|
||||||
|
for ( ; pos < end && *pos; pos++) {
|
||||||
|
value <<= 4;
|
||||||
|
char c = *pos;
|
||||||
|
if ('a' <= c && c <= 'f') {
|
||||||
|
value += c - 'a' + 10;
|
||||||
|
} else if ('A' <= c && c <= 'F') {
|
||||||
|
value += c - 'A' + 10;
|
||||||
|
} else if ('0' <= c && c <= '9') {
|
||||||
|
value += c - '0';
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos != end) {
|
||||||
|
throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
|
||||||
|
}
|
||||||
|
return std::make_pair(value, pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * parse_space(const char * src, bool newline_ok) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
|
||||||
|
(newline_ok && (*pos == '\r' || *pos == '\n'))) {
|
||||||
|
if (*pos == '#') {
|
||||||
|
while (*pos && *pos != '\r' && *pos != '\n') {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * parse_name(const char * src) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (is_word_char(*pos)) {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos == src) {
|
||||||
|
throw std::runtime_error(std::string("expecting name at ") + src);
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * parse_int(const char * src) {
|
||||||
|
const char * pos = src;
|
||||||
|
while (is_digit_char(*pos)) {
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos == src) {
|
||||||
|
throw std::runtime_error(std::string("expecting integer at ") + src);
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<uint32_t, const char *> parse_char(const char * src) {
|
||||||
|
if (*src == '\\') {
|
||||||
|
switch (src[1]) {
|
||||||
|
case 'x': return parse_hex(src + 2, 2);
|
||||||
|
case 'u': return parse_hex(src + 2, 4);
|
||||||
|
case 'U': return parse_hex(src + 2, 8);
|
||||||
|
case 't': return std::make_pair('\t', src + 2);
|
||||||
|
case 'r': return std::make_pair('\r', src + 2);
|
||||||
|
case 'n': return std::make_pair('\n', src + 2);
|
||||||
|
case '\\':
|
||||||
|
case '"':
|
||||||
|
case '[':
|
||||||
|
case ']':
|
||||||
|
return std::make_pair(src[1], src + 2);
|
||||||
|
default:
|
||||||
|
throw std::runtime_error(std::string("unknown escape at ") + src);
|
||||||
|
}
|
||||||
|
} else if (*src) {
|
||||||
|
return decode_utf8(src);
|
||||||
|
}
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_grammar_char(FILE * file, uint32_t c) {
|
||||||
|
if (0x20 <= c && c <= 0x7f) {
|
||||||
|
fprintf(file, "%c", static_cast<char>(c));
|
||||||
|
} else {
|
||||||
|
// cop out of encoding UTF-8
|
||||||
|
fprintf(file, "<U+%04X>", c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool is_char_element(llama_grammar_element elem) {
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_CHAR: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY: return true;
|
||||||
|
default: return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
|
||||||
|
for (auto elem : rule) {
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
|
||||||
|
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
|
||||||
|
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
|
||||||
|
}
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END:
|
||||||
|
case LLAMA_GRETYPE_ALT:
|
||||||
|
case LLAMA_GRETYPE_RULE_REF:
|
||||||
|
fprintf(file, "(%u) ", elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
fprintf(file, "(\"");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
fprintf(file, "\") ");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print_rule(
|
||||||
|
FILE * file,
|
||||||
|
uint32_t rule_id,
|
||||||
|
const llama_grammar_rule & rule,
|
||||||
|
const std::map<uint32_t, std::string> & symbol_id_names) {
|
||||||
|
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||||
|
}
|
||||||
|
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||||
|
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||||
|
llama_grammar_element elem = rule[i];
|
||||||
|
switch (elem.type) {
|
||||||
|
case LLAMA_GRETYPE_END:
|
||||||
|
throw std::runtime_error(
|
||||||
|
"unexpected end of rule: " + std::to_string(rule_id) + "," +
|
||||||
|
std::to_string(i));
|
||||||
|
case LLAMA_GRETYPE_ALT:
|
||||||
|
fprintf(file, "| ");
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_RULE_REF:
|
||||||
|
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
fprintf(file, "[");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
fprintf(file, "[^");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
|
||||||
|
std::to_string(rule_id) + "," + std::to_string(i));
|
||||||
|
}
|
||||||
|
fprintf(file, "-");
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
if (i == 0 || !is_char_element(rule[i - 1])) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
|
||||||
|
std::to_string(rule_id) + "," + std::to_string(i));
|
||||||
|
}
|
||||||
|
print_grammar_char(file, elem.value);
|
||||||
|
break;
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
fprintf(file, ".");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (is_char_element(elem)) {
|
||||||
|
switch (rule[i + 1].type) {
|
||||||
|
case LLAMA_GRETYPE_CHAR_ALT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
fprintf(file, "] ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// implementation
|
||||||
|
//
|
||||||
|
|
||||||
|
uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
|
||||||
|
uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
|
||||||
|
auto result = symbol_ids.emplace(std::string(src, len), next_id);
|
||||||
|
return result.first->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
|
||||||
|
uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
|
||||||
|
symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
|
||||||
|
return next_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
|
||||||
|
if (rules.size() <= rule_id) {
|
||||||
|
rules.resize(rule_id + 1);
|
||||||
|
}
|
||||||
|
rules[rule_id] = rule;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_alternates(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
uint32_t rule_id,
|
||||||
|
bool is_nested) {
|
||||||
|
llama_grammar_rule rule;
|
||||||
|
const char * pos = parse_sequence(src, rule_name, rule, is_nested);
|
||||||
|
while (*pos == '|') {
|
||||||
|
rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
|
pos = parse_space(pos + 1, true);
|
||||||
|
pos = parse_sequence(pos, rule_name, rule, is_nested);
|
||||||
|
}
|
||||||
|
rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
add_rule(rule_id, rule);
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_sequence(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
llama_grammar_rule & rule,
|
||||||
|
bool is_nested) {
|
||||||
|
size_t last_sym_start = rule.size();
|
||||||
|
const char * pos = src;
|
||||||
|
|
||||||
|
auto handle_repetitions = [&](int min_times, int max_times) {
|
||||||
|
|
||||||
|
if (last_sym_start == rule.size()) {
|
||||||
|
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply transformation to previous symbol (last_sym_start to end) according to
|
||||||
|
// the following rewrite rules:
|
||||||
|
// S{m,n} --> S S S (m times) S'(n-m)
|
||||||
|
// S'(x) ::= S S'(x-1) |
|
||||||
|
// (... n-m definitions of these S' rules ...)
|
||||||
|
// S'(1) ::= S |
|
||||||
|
// S{m,} --> S S S (m times) S'
|
||||||
|
// S' ::= S S' |
|
||||||
|
// S* --> S{0,}
|
||||||
|
// --> S' ::= S S' |
|
||||||
|
// S+ --> S{1,}
|
||||||
|
// --> S S'
|
||||||
|
// S' ::= S S' |
|
||||||
|
// S? --> S{0,1}
|
||||||
|
// --> S'
|
||||||
|
// S' ::= S |
|
||||||
|
|
||||||
|
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
|
||||||
|
if (min_times == 0) {
|
||||||
|
rule.resize(last_sym_start);
|
||||||
|
} else {
|
||||||
|
// Repeat the previous elements (min_times - 1) times
|
||||||
|
for (int i = 1; i < min_times; i++) {
|
||||||
|
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t last_rec_rule_id = 0;
|
||||||
|
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
|
||||||
|
|
||||||
|
llama_grammar_rule rec_rule(prev_rule);
|
||||||
|
for (int i = 0; i < n_opt; i++) {
|
||||||
|
rec_rule.resize(prev_rule.size());
|
||||||
|
uint32_t rec_rule_id = generate_symbol_id( rule_name);
|
||||||
|
if (i > 0 || max_times < 0) {
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
|
||||||
|
}
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
|
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
add_rule( rec_rule_id, rec_rule);
|
||||||
|
last_rec_rule_id = rec_rule_id;
|
||||||
|
}
|
||||||
|
if (n_opt > 0) {
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
while (*pos) {
|
||||||
|
if (*pos == '"') { // literal string
|
||||||
|
pos++;
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != '"') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '[') { // char range(s)
|
||||||
|
pos++;
|
||||||
|
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||||
|
if (*pos == '^') {
|
||||||
|
pos++;
|
||||||
|
start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
|
}
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
while (*pos != ']') {
|
||||||
|
if (!*pos) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto char_pair = parse_char(pos);
|
||||||
|
pos = char_pair.second;
|
||||||
|
enum llama_gretype type = last_sym_start < rule.size()
|
||||||
|
? LLAMA_GRETYPE_CHAR_ALT
|
||||||
|
: start_type;
|
||||||
|
|
||||||
|
rule.push_back({type, char_pair.first});
|
||||||
|
if (pos[0] == '-' && pos[1] != ']') {
|
||||||
|
if (!pos[1]) {
|
||||||
|
throw std::runtime_error("unexpected end of input");
|
||||||
|
}
|
||||||
|
auto endchar_pair = parse_char(pos + 1);
|
||||||
|
pos = endchar_pair.second;
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (is_word_char(*pos)) { // rule reference
|
||||||
|
const char * name_end = parse_name(pos);
|
||||||
|
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
|
||||||
|
pos = parse_space(name_end, is_nested);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
|
} else if (*pos == '(') { // grouping
|
||||||
|
// parse nested alternates into synthesized rule
|
||||||
|
pos = parse_space(pos + 1, true);
|
||||||
|
uint32_t sub_rule_id = generate_symbol_id(rule_name);
|
||||||
|
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
// output reference to synthesized rule
|
||||||
|
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
|
||||||
|
if (*pos != ')') {
|
||||||
|
throw std::runtime_error(std::string("expecting ')' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '.') { // any char
|
||||||
|
last_sym_start = rule.size();
|
||||||
|
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == '*') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, -1);
|
||||||
|
} else if (*pos == '+') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(1, -1);
|
||||||
|
} else if (*pos == '?') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
handle_repetitions(0, 1);
|
||||||
|
} else if (*pos == '{') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (!is_digit_char(*pos)) {
|
||||||
|
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||||
|
}
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
int min_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
|
||||||
|
int max_times = -1;
|
||||||
|
|
||||||
|
if (*pos == '}') {
|
||||||
|
max_times = min_times;
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else if (*pos == ',') {
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
|
||||||
|
if (is_digit_char(*pos)) {
|
||||||
|
const char * int_end = parse_int(pos);
|
||||||
|
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||||
|
pos = parse_space(int_end, is_nested);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (*pos != '}') {
|
||||||
|
throw std::runtime_error(std::string("expecting '}' at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 1, is_nested);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(std::string("expecting ',' at ") + pos);
|
||||||
|
}
|
||||||
|
handle_repetitions(min_times, max_times);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * llama_grammar_parser::parse_rule(const char * src) {
|
||||||
|
const char * name_end = parse_name(src);
|
||||||
|
const char * pos = parse_space(name_end, false);
|
||||||
|
size_t name_len = name_end - src;
|
||||||
|
uint32_t rule_id = get_symbol_id(src, name_len);
|
||||||
|
const std::string name(src, name_len);
|
||||||
|
|
||||||
|
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
|
||||||
|
throw std::runtime_error(std::string("expecting ::= at ") + pos);
|
||||||
|
}
|
||||||
|
pos = parse_space(pos + 3, true);
|
||||||
|
|
||||||
|
pos = parse_alternates(pos, name, rule_id, false);
|
||||||
|
|
||||||
|
if (*pos == '\r') {
|
||||||
|
pos += pos[1] == '\n' ? 2 : 1;
|
||||||
|
} else if (*pos == '\n') {
|
||||||
|
pos++;
|
||||||
|
} else if (*pos) {
|
||||||
|
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
|
||||||
|
}
|
||||||
|
return parse_space(pos, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_grammar_parser::parse(const char * src) {
|
||||||
|
try {
|
||||||
|
const char * pos = parse_space(src, true);
|
||||||
|
while (*pos) {
|
||||||
|
pos = parse_rule(pos);
|
||||||
|
}
|
||||||
|
// Validate the state to ensure that all rules are defined
|
||||||
|
for (const auto & rule : rules) {
|
||||||
|
if (rule.empty()) {
|
||||||
|
throw std::runtime_error("Undefined rule");
|
||||||
|
}
|
||||||
|
for (const auto & elem : rule) {
|
||||||
|
if (elem.type == LLAMA_GRETYPE_RULE_REF) {
|
||||||
|
// Ensure that the rule at that location exists
|
||||||
|
if (elem.value >= rules.size() || rules[elem.value].empty()) {
|
||||||
|
// Get the name of the rule that is missing
|
||||||
|
for (const auto & kv : symbol_ids) {
|
||||||
|
if (kv.second == elem.value) {
|
||||||
|
throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
|
||||||
|
rules.clear();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_parser::print(FILE * file) {
|
||||||
|
try {
|
||||||
|
std::map<uint32_t, std::string> symbol_id_names;
|
||||||
|
for (const auto & kv : symbol_ids) {
|
||||||
|
symbol_id_names[kv.second] = kv.first;
|
||||||
|
}
|
||||||
|
for (size_t i = 0, end = rules.size(); i < end; i++) {
|
||||||
|
// fprintf(file, "%zu: ", i);
|
||||||
|
// print_rule_binary(file, rules[i]);
|
||||||
|
print_rule(file, uint32_t(i), rules[i], symbol_id_names);
|
||||||
|
// fprintf(file, "\n");
|
||||||
|
}
|
||||||
|
} catch (const std::exception & err) {
|
||||||
|
fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stack llama_grammar_parser::c_rules() const {
|
||||||
|
llama_grammar_stack ret;
|
||||||
|
ret.reserve(rules.size());
|
||||||
|
for (const auto & rule : rules) {
|
||||||
|
ret.push_back(rule.data());
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns true iff pos points to the end of one of the definitions of a rule
|
// returns true iff pos points to the end of one of the definitions of a rule
|
||||||
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
 static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
         const llama_grammar_element * pos,
         const uint32_t                chr) {
-
     bool found            = false;
     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 
@@ -225,36 +742,6 @@ static void llama_grammar_advance_stack(
     }
 }
 
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t               chr,
-        llama_grammar_stacks       & new_stacks) {
-    new_stacks.clear();
-
-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        auto match = llama_grammar_match_char(stack.back(), chr);
-        if (match.first) {
-            const llama_grammar_element * pos = match.second;
-
-            // update top of stack to next element, if any
-            llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-            if (!llama_grammar_is_end_of_sequence(pos)) {
-                new_stack.push_back(pos);
-            }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
-        }
-    }
-}
-
 static llama_grammar_candidates llama_grammar_reject_candidates(
         const llama_grammar_rules  & rules,
         const llama_grammar_stacks & stacks,
@ -270,9 +757,98 @@ static llama_grammar_candidates llama_grammar_reject_candidates(
|
|||||||
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
||||||
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
||||||
}
|
}
|
||||||
|
|
||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_grammar_detect_left_recursion(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
size_t rule_index,
|
||||||
|
std::vector<bool> * rules_visited,
|
||||||
|
std::vector<bool> * rules_in_progress,
|
||||||
|
std::vector<bool> * rules_may_be_empty) {
|
||||||
|
if ((*rules_in_progress)[rule_index]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = true;
|
||||||
|
|
||||||
|
const llama_grammar_rule & rule = rules[rule_index];
|
||||||
|
|
||||||
|
// First check if the rule might produce the empty string. This could be done combined with the second
|
||||||
|
// step but it's more readable as two steps.
|
||||||
|
bool at_rule_start = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
if (at_rule_start) {
|
||||||
|
(*rules_may_be_empty)[rule_index] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
at_rule_start = true;
|
||||||
|
} else {
|
||||||
|
at_rule_start = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
||||||
|
// be empty)
|
||||||
|
bool recurse_into_nonterminal = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
||||||
|
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
recurse_into_nonterminal = true;
|
||||||
|
} else {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = false;
|
||||||
|
(*rules_visited)[rule_index] = true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
|
||||||
|
return grammar->rules;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
|
||||||
|
return grammar->stacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
const uint32_t chr,
|
||||||
|
llama_grammar_stacks & stacks_new) {
|
||||||
|
stacks_new.clear();
|
||||||
|
stacks_new.reserve(stacks.size());
|
||||||
|
|
||||||
|
for (const auto & stack : stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||||
|
if (match.first) {
|
||||||
|
const llama_grammar_element * pos = match.second;
|
||||||
|
|
||||||
|
// update top of stack to next element, if any
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
new_stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(rules, new_stack, stacks_new);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
@ -328,63 +904,10 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
|||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_grammar_detect_left_recursion(
|
////////////////////
|
||||||
const llama_grammar_rules & rules,
|
|
||||||
size_t rule_index,
|
|
||||||
std::vector<bool> * rules_visited,
|
|
||||||
std::vector<bool> * rules_in_progress,
|
|
||||||
std::vector<bool> * rules_may_be_empty) {
|
|
||||||
if ((*rules_in_progress)[rule_index]) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
(*rules_in_progress)[rule_index] = true;
|
|
||||||
|
|
||||||
const llama_grammar_rule & rule = rules[rule_index];
|
|
||||||
|
|
||||||
// First check if the rule might produce the empty string. This could be done combined with the second
|
|
||||||
// step but it's more readable as two steps.
|
|
||||||
bool at_rule_start = true;
|
|
||||||
for (size_t i = 0; i < rule.size(); i++) {
|
|
||||||
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
|
||||||
if (at_rule_start) {
|
|
||||||
(*rules_may_be_empty)[rule_index] = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
at_rule_start = true;
|
|
||||||
} else {
|
|
||||||
at_rule_start = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
|
||||||
// be empty)
|
|
||||||
bool recurse_into_nonterminal = true;
|
|
||||||
for (size_t i = 0; i < rule.size(); i++) {
|
|
||||||
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
|
||||||
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
|
||||||
recurse_into_nonterminal = false;
|
|
||||||
}
|
|
||||||
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
|
||||||
recurse_into_nonterminal = true;
|
|
||||||
} else {
|
|
||||||
recurse_into_nonterminal = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(*rules_in_progress)[rule_index] = false;
|
|
||||||
(*rules_visited)[rule_index] = true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// grammar - external
|
|
||||||
//
|
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index) {
|
size_t start_rule_index) {
|
||||||
@ -438,22 +961,104 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
||||||
|
llama_grammar_parser parser;
|
||||||
|
|
||||||
|
// if there is a grammar, parse it
|
||||||
|
if (!parser.parse(grammar_str)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// will be empty (default) if there are parse errors
|
||||||
|
if (parser.rules.empty()) {
|
||||||
|
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that there is a "root" node.
|
||||||
|
if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
|
||||||
|
fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
|
||||||
|
|
||||||
|
const size_t n_rules = grammar_rules.size();
|
||||||
|
const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
|
||||||
|
|
||||||
|
const llama_grammar_element * pos;
|
||||||
|
|
||||||
|
// copy rule definitions into vectors
|
||||||
|
llama_grammar_rules vec_rules(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
||||||
|
vec_rules[i].push_back(*pos);
|
||||||
|
}
|
||||||
|
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for left recursion
|
||||||
|
std::vector<bool> rules_visited(n_rules);
|
||||||
|
std::vector<bool> rules_in_progress(n_rules);
|
||||||
|
std::vector<bool> rules_may_be_empty(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
if (rules_visited[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
||||||
|
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over alternates of start rule to build initial stacks
|
||||||
|
llama_grammar_stacks stacks;
|
||||||
|
pos = vec_rules[start_rule_index].data();
|
||||||
|
do {
|
||||||
|
llama_grammar_stack stack;
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// if alternate is nonempty, add to stack
|
||||||
|
stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||||
|
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// scan to end of alternate def
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos->type == LLAMA_GRETYPE_ALT) {
|
||||||
|
// there's another alternate def of this rule to process
|
||||||
|
pos++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (true);
|
||||||
|
|
||||||
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
||||||
}
|
}
|
||||||
|
|
||||||
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                         result->stacks[is][ie] = &result->rules[ir0][ir1];
                     }
                 }
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
     }
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    candidates_decoded.reserve(cur_p->size);
 
     llama_grammar_candidates candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
+    candidates_grammar.reserve(cur_p->size);
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const llama_token id      = cur_p->data[i].id;
+        const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(*grammar.vocab, id)) {
             if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
+                cur_p->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            cur_p->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+        cur_p->data[reject.index].logit = -INFINITY;
     }
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+    if (llama_token_is_eog_impl(*grammar.vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }
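Aside (not part of the commit): the hunk above shows how the grammar constrains sampling — every candidate token the grammar rejects has its logit set to -INFINITY, so the subsequent softmax/argmax can never select it. A self-contained sketch of that masking step; the token_data struct, the token ids, and the allowed set are made up for illustration:

#include <cmath>
#include <cstdio>
#include <set>
#include <vector>

// Hypothetical stand-in for llama_token_data, for illustration only.
struct token_data { int id; float logit; };

int main() {
    std::vector<token_data> cur = { {10, 1.5f}, {11, 2.0f}, {12, 0.3f} };
    std::set<int> allowed_by_grammar = { 10, 12 };   // tokens the grammar stacks would accept

    // Masking step: rejected candidates get logit = -INFINITY.
    for (auto & td : cur) {
        if (allowed_by_grammar.count(td.id) == 0) {
            td.logit = -INFINITY;
        }
    }

    // Greedy pick over the masked candidates.
    const token_data * best = &cur[0];
    for (const auto & td : cur) {
        if (td.logit > best->logit) {
            best = &td;
        }
    }
    printf("selected token %d\n", best->id);   // prints 10: token 11 was masked out
    return 0;
}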
@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string & piece = vocab->cache_token_to_piece.at(token);
|
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
|
||||||
|
|
||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
const auto decoded = decode_utf8(piece, grammar.partial_utf8);
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
llama_grammar_stacks tmp_new_stacks;
|
llama_grammar_stacks stacks_new;
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
|
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
||||||
grammar->stacks = tmp_new_stacks;
|
grammar.stacks = std::move(stacks_new);
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar->partial_utf8 = decoded.second;
|
grammar.partial_utf8 = decoded.second;
|
||||||
GGML_ASSERT(!grammar->stacks.empty());
|
GGML_ASSERT(!grammar.stacks.empty());
|
||||||
|
|
||||||
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,115 @@
|
|||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
struct llama_sampling;
|
|
||||||
|
// grammar element type
|
||||||
|
enum llama_gretype {
|
||||||
|
// end of rule definition
|
||||||
|
LLAMA_GRETYPE_END = 0,
|
||||||
|
|
||||||
|
// start of alternate definition for rule
|
||||||
|
LLAMA_GRETYPE_ALT = 1,
|
||||||
|
|
||||||
|
// non-terminal element: reference to rule
|
||||||
|
LLAMA_GRETYPE_RULE_REF = 2,
|
||||||
|
|
||||||
|
// terminal element: character (code point)
|
||||||
|
LLAMA_GRETYPE_CHAR = 3,
|
||||||
|
|
||||||
|
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||||
|
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||||
|
// be an inclusive range ([a-z])
|
||||||
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||||
|
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||||
|
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||||
|
|
||||||
|
// any character (.)
|
||||||
|
LLAMA_GRETYPE_CHAR_ANY = 7,
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct llama_grammar_element {
|
||||||
|
enum llama_gretype type;
|
||||||
|
uint32_t value; // Unicode code point or rule ID
|
||||||
|
} llama_grammar_element;
|
||||||
|
|
||||||
|
struct llama_partial_utf8 {
|
||||||
|
uint32_t value; // bit value so far (unshifted)
|
||||||
|
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_grammar_candidate {
|
||||||
|
size_t index;
|
||||||
|
const uint32_t * code_points;
|
||||||
|
llama_partial_utf8 partial_utf8;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||||
|
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
||||||
|
|
||||||
|
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||||
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||||
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||||
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
|
// positions
|
||||||
|
void llama_grammar_accept(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
uint32_t chr,
|
||||||
|
llama_grammar_stacks & stacks_new);
|
||||||
|
|
||||||
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
const llama_grammar_candidates & candidates);
|
||||||
|
|
||||||
|
struct llama_grammar_parser {
|
||||||
|
std::map<std::string, uint32_t> symbol_ids;
|
||||||
|
|
||||||
|
llama_grammar_rules rules;
|
||||||
|
|
||||||
|
llama_grammar_stack c_rules() const;
|
||||||
|
|
||||||
|
uint32_t get_symbol_id(const char * src, size_t len);
|
||||||
|
uint32_t generate_symbol_id(const std::string & base_name);
|
||||||
|
|
||||||
|
void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
|
||||||
|
|
||||||
|
const char * parse_alternates(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
uint32_t rule_id,
|
||||||
|
bool is_nested);
|
||||||
|
|
||||||
|
const char * parse_sequence(
|
||||||
|
const char * src,
|
||||||
|
const std::string & rule_name,
|
||||||
|
llama_grammar_rule & rule,
|
||||||
|
bool is_nested);
|
||||||
|
|
||||||
|
const char * parse_rule(const char * src);
|
||||||
|
|
||||||
|
bool parse(const char * src);
|
||||||
|
void print(FILE * file);
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
const llama_grammar_rules rules;
|
// note: allow null vocab for testing (not great)
|
||||||
|
const llama_vocab * vocab;
|
||||||
|
|
||||||
|
const llama_grammar_rules rules; // TODO: shared ptr
|
||||||
llama_grammar_stacks stacks;
|
llama_grammar_stacks stacks;
|
||||||
|
|
||||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||||
@ -17,23 +121,24 @@ struct llama_grammar {
|
|||||||
// internal API
|
// internal API
|
||||||
//
|
//
|
||||||
|
|
||||||
|
// note: needed for tests (not great)
|
||||||
struct llama_grammar * llama_grammar_init_impl(
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index);
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
|
||||||
|
|
||||||
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
|
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
|
||||||
|
|
||||||
void llama_grammar_sample_impl(
|
// TODO: move the API below as member functions of llama_grammar
|
||||||
const struct llama_grammar * grammar,
|
void llama_grammar_apply_impl(
|
||||||
const struct llama_vocab * vocab,
|
const struct llama_grammar & grammar,
|
||||||
const struct llama_sampling * smpl,
|
llama_token_data_array * cur_p);
|
||||||
llama_token_data_array * candidates);
|
|
||||||
|
|
||||||
void llama_grammar_accept_token_impl(
|
void llama_grammar_accept_impl(
|
||||||
struct llama_grammar * grammar,
|
struct llama_grammar & grammar,
|
||||||
const struct llama_vocab * vocab,
|
|
||||||
const struct llama_sampling * smpl,
|
|
||||||
llama_token token);
|
llama_token token);
|
||||||
|
@@ -1,8 +1,11 @@
 #pragma once
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include <string>
+#include <vector>
+#include <stdexcept>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -21,14 +24,31 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3)
 void llama_log_internal        (ggml_log_level level, const char * format, ...);
 void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
 
+#define LLAMA_LOG(...)       llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
 #define LLAMA_LOG_INFO(...)  llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define LLAMA_LOG_CONT(...)  llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
 
 //
 // helpers
 //
 
+struct time_meas {
+    time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        if (t_start_us >= 0) {
+            t_acc += ggml_time_us() - t_start_us;
+        }
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
         return;
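Aside (not part of the commit): time_meas added above is a small RAII timer — it records a start time in its constructor and adds the elapsed microseconds to a caller-owned accumulator in its destructor. A self-contained analogue using std::chrono in place of ggml_time_us(); the name scoped_timer_us and the sleep workload are invented for illustration:

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <thread>

// Chrono-based analogue of the time_meas helper, so the sketch compiles on its own.
struct scoped_timer_us {
    scoped_timer_us(int64_t & acc, bool disable = false)
        : t_start(disable ? -1 : now_us()), t_acc(acc) {}

    ~scoped_timer_us() {
        if (t_start >= 0) {
            t_acc += now_us() - t_start;   // accumulate elapsed time on scope exit
        }
    }

    static int64_t now_us() {
        using namespace std::chrono;
        return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
    }

    const int64_t t_start;
    int64_t     & t_acc;
};

int main() {
    int64_t t_total_us = 0;
    {
        scoped_timer_us tm(t_total_us);                              // starts the clock
        std::this_thread::sleep_for(std::chrono::milliseconds(5));   // stand-in workload
    }                                                                // destructor adds the elapsed time
    printf("accumulated: %lld us\n", (long long) t_total_us);
    return 0;
}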
@ -45,3 +65,117 @@ static void replace_all(std::string & s, const std::string & search, const std::
|
|||||||
builder.append(s, last_pos, std::string::npos);
|
builder.append(s, last_pos, std::string::npos);
|
||||||
s = std::move(builder);
|
s = std::move(builder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||||
|
struct llama_context * ctx
|
||||||
|
);
|
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+
+    std::vector<T> data;
+};
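ring_buffer keeps the last `capacity` elements pushed into it, silently overwriting the oldest once full; rat(i) indexes backwards from the most recent element. A small usage sketch (illustrative only, not part of the diff), e.g. for remembering the most recent token ids:

    ring_buffer<llama_token> prev(4);
    for (llama_token id : {10, 11, 12, 13, 14}) {
        prev.push_back(id);                 // the 5th push overwrites 10
    }
    // prev.to_vector() == {11, 12, 13, 14}
    // prev.rat(0) == 14 (newest), prev.rat(3) == 11 (oldest still stored)
    llama_token oldest = prev.pop_front();  // removes and returns 11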
File diff suppressed because it is too large
@@ -1,56 +1,29 @@
 #pragma once

-#include "llama-impl.h"
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?

-struct llama_sampling {
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+#include "llama-grammar.h"

-    std::mt19937 rng;
+#include <unordered_map>

-    int32_t n_vocab = 0;
+struct llama_vocab;
+struct llama_grammar;

-    mutable int64_t t_sample_us = 0;
-    mutable int32_t n_sample = 0;
+// sampler chain

-    void reset_timings() const {
-        t_sample_us = 0;
-        n_sample = 0;
-    }
+struct llama_sampler_chain {
+    llama_sampler_chain_params params;
+
+    std::vector<struct llama_sampler *> samplers;
+
+    // timing
+
+    mutable int64_t t_sample_us;
+
+    mutable int32_t n_sample;
 };

-//
-// internal API
-//
+struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab & vocab,
+                      const char * grammar_str,
+                      const char * grammar_root);

-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
-
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-       llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-                        size_t   penalty_last_n,
-                         float   penalty_repeat,
-                         float   penalty_freq,
-                         float   penalty_present);
-
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
-                        float   scale);
-
-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
-    }
-    struct naive_trie * traverse(const char c) {
+
+        return std::make_pair(key, offset);
+    }
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
         // traverse the token matcher trie to find a matching token
         bool single_codepoint_token_found = false;
         const struct best_tokenization & current_best = tokenization_results[input_offset];
-        struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+        const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);

         while (prefix_offset <= input_len && node != NULL) {
             // check if we found valid token in prefix
@@ -963,7 +963,7 @@ private:
     /*
      * This structure is a view wrapper for XOR-compressed double array (XCDA)
      * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries.
-     * Eeach bit-packed entry contains:
+     * Each bit-packed entry contains:
      * - BASE array value in bits 10-30
      * - LCHECK array value in bits 0-7
      * - LEAF array value in bit 9
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };

+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
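For reference, the RWKV vocab stores every token as an escaped string, and the function above undoes that escaping: \t, \n and \r become the corresponding control bytes, \xNN (lowercase hex, as the parser above assumes) becomes a raw byte, and any other escaped character is taken literally. Two worked examples as a sketch (the calls are illustrative; the function itself is file-local):

    // "a\\nb"           -> { 'a', 0x0a, 'b' }
    // "\\xe4\\xbd\\xa0"  -> { 0xe4, 0xbd, 0xa0 }   // three raw UTF-8 bytes
    const std::vector<uint8_t> bytes = llama_unescape_rwkv_token("a\\nb");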
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                         output.push_back(vocab.special_eos_id);
                     }
                 } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1448,11 +1570,7 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
 }

 bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
+    return token != -1 && vocab.special_eog_ids.count(token) > 0;
 }

 bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
@@ -1616,6 +1734,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             }
             break;
         }
+        case LLAMA_VOCAB_TYPE_RWKV: {
+            std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+            // If we don't have enough space, return an error
+            if (result.size() > (size_t)length) {
+                return -(int)result.size();
+            }
+
+            memcpy(buf, result.data(), result.size());
+            return (int)result.size();
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -6,6 +6,7 @@
 #include <vector>
 #include <unordered_map>
 #include <map>
+#include <set>

 struct llama_vocab {
     using id = llama_token;
@@ -18,6 +19,8 @@ struct llama_vocab {
         tattr attr;
     };

+    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

@@ -47,6 +50,9 @@ struct llama_vocab {
     id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
     id special_eom_id = -1;

+    // set of all tokens that cause "end of generation"
+    std::set<id> special_eog_ids;
+
     // tokenizer flags
     bool tokenizer_add_space_prefix = false;
     bool tokenizer_add_bos          = false;
@@ -62,8 +68,6 @@ struct llama_vocab {
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 };

-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
 //
 // internal API
 //
@@ -76,6 +80,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
         bool add_special,
         bool parse_special = false);

+// TODO: move the API below as member functions of llama_vocab
 llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);

 const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
File diff suppressed because it is too large
@@ -33,12 +33,15 @@

 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9

 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
@@ -53,8 +56,10 @@ extern "C" {
    // TODO: show sample usage
    //

+    // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampler;

     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -66,6 +71,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };

     // pre-tokenization types
@@ -166,6 +172,8 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ1_0    = 36, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_TQ2_0    = 37, // except 1d tensors

         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -198,6 +206,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW   = 2, // split rows across GPUs
     };

+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id;    // token id
         float logit;       // log-odds of the token
@@ -205,8 +214,10 @@ extern "C" {
     } llama_token_data;

     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
         size_t size;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;

@@ -267,9 +278,9 @@ extern "C" {
         enum llama_split_mode split_mode; // how to split the model across multiple GPUs

         // main_gpu interpretation depends on split_mode:
-        // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
-        // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
-        // LLAMA_SPLIT_LAYER: ignored
+        // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model
+        // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results
+        // LLAMA_SPLIT_MODE_LAYER: ignored
         int32_t main_gpu;

         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
@@ -299,13 +310,12 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     // https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
-        uint32_t seed;            // RNG seed, -1 for random
         uint32_t n_ctx;           // text context, 0 = from model
         uint32_t n_batch;         // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;        // physical maximum batch size
         uint32_t n_seq_max;       // max number of sequences (i.e. distinct states for recurrent models)
-        uint32_t n_threads;       // number of threads to use for generation
-        uint32_t n_threads_batch; // number of threads to use for batch processing
+        int32_t  n_threads;       // number of threads to use for generation
+        int32_t  n_threads_batch; // number of threads to use for batch processing

         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
@@ -327,11 +337,13 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]

-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings

         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -355,56 +367,14 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;

-    // grammar types
-    struct llama_grammar;
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;

-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END            = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT            = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF       = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR           = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT       = 4,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY       = 7,
-    };
-
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t           value; // Unicode code point or rule ID
-    } llama_grammar_element;
-
-    // performance timing information
-    struct llama_timings {
-        double t_start_ms;
-        double t_end_ms;
-        double t_load_ms;
-        double t_sample_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
-
-        int32_t n_sample;
-        int32_t n_p_eval;
-        int32_t n_eval;
-    };
+    typedef struct llama_sampler_chain_params {
+        bool no_perf; // whether to measure performance timings
+    } llama_sampler_chain_params;

     // used in chat template
     typedef struct llama_chat_message {
@@ -416,8 +386,10 @@ extern "C" {
     struct llama_lora_adapter;

     // Helpers for getting default parameters
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
     LLAMA_API struct llama_model_params          llama_model_default_params(void);
     LLAMA_API struct llama_context_params        llama_context_default_params(void);
+    LLAMA_API struct llama_sampler_chain_params  llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);

     // Initialize the llama + ggml backend
@@ -428,6 +400,13 @@ extern "C" {
     //optional:
     LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);

+    // Optional: an auto threadpool gets created in ggml if not passed explicitly
+    LLAMA_API void llama_attach_threadpool(
+            struct llama_context * ctx,
+            ggml_threadpool_t      threadpool,
+            ggml_threadpool_t      threadpool_batch);
+    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+
     // Call once at the end of the program - currently only used for MPI
     LLAMA_API void llama_backend_free(void);

@@ -437,6 +416,7 @@ extern "C" {

     LLAMA_API void llama_free_model(struct llama_model * model);

+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
@@ -452,22 +432,22 @@ extern "C" {
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);

-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch  (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);

-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
+    LLAMA_API int32_t llama_n_head     (const struct llama_model * model);
+
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
@@ -696,7 +676,7 @@ extern "C" {
     //

     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -837,13 +817,13 @@ extern "C" {
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
-    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
+    LLAMA_API void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch);

     // Get the number of threads used for generation of a single token.
-    LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads(struct llama_context * ctx);

     // Get the number of threads used for prompt and batch processing (multiple token).
-    LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
+    LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);

     // Set whether the model is in embeddings mode or not
     // If true, embeddings will be returned but logits will not
@@ -999,121 +979,114 @@ extern "C" {
                                    int32_t   length);

     //
-    // Grammar
+    // Sampling API
+    //
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //
+    //    // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+    //    // this sampler will be responsible to select the actual token
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //

-    /// Initialize a llama_grammar.
-    ///
-    /// @param rules The rule elements of the grammar to initialize.
-    /// @param n_rules The number of rules.
-    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
-    /// @return The initialized llama_grammar or nullptr if initialization failed.
-    LLAMA_API struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                    size_t n_rules,
-                                    size_t start_rule_index);
-
-    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
-    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_grammar_sample(
-            const struct llama_grammar * grammar,
-            const struct llama_context * ctx,
-                llama_token_data_array * candidates);
-    LLAMA_API DEPRECATED(void llama_sample_grammar(
-            struct llama_context * ctx,
-            llama_token_data_array * candidates,
-            const struct llama_grammar * grammar),
-        "use llama_grammar_sample instead");
-
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-            llama_token token);
+    typedef void * llama_sampler_context_t;
+
+    // user code can implement the interface below in order to create custom llama_sampler
+    struct llama_sampler_i {
+        const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
+        void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
+        struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+        void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+
+        // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
+        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+    };
+
+    struct llama_sampler {
+        struct llama_sampler_i * iface;
+        llama_sampler_context_t  ctx;
+    };
+
+    // mirror of llama_sampler_i:
+    LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+    // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+    LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
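llama_sampler_i is the extension point for user-defined samplers: only apply is required, the remaining callbacks may stay NULL. The sketch below is illustrative (not code from this commit) and shows a stateless sampler that masks out one token id by pushing its logit to a very low value; for brevity the banned id is stored directly in the context pointer:

    // hypothetical custom sampler implementing llama_sampler_i
    static void ban_token_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
        const llama_token banned = (llama_token)(intptr_t) smpl->ctx;
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].id == banned) {
                cur_p->data[i].logit = -1e9f; // effectively removes the token from consideration
            }
        }
    }

    static const char * ban_token_name(const struct llama_sampler * /*smpl*/) { return "ban-token"; }

    static struct llama_sampler_i ban_token_iface = {
        /*.name   =*/ ban_token_name,
        /*.accept =*/ NULL,
        /*.apply  =*/ ban_token_apply,
        /*.reset  =*/ NULL,
        /*.clone  =*/ NULL,
        /*.free   =*/ NULL,
    };

    // token id 42 is only an example; apply to a candidate array with llama_sampler_apply(&ban42, &cur_p)
    struct llama_sampler ban42 = { &ban_token_iface, (llama_sampler_context_t)(intptr_t) 42 };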

     //
-    // Sampling functions
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
     //

-    // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);

-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-    LLAMA_API void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present);
+    // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+    LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+    LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);

-    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param logits Logits extracted from the original generation context.
-    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_apply_guidance(
-            struct llama_context * ctx,
-                           float * logits,
-                           float * logits_guidance,
-                           float   scale);
+    // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
+    LLAMA_API struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i);
+
+    // available samplers:
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist   (uint32_t seed);

     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    LLAMA_API void llama_sample_softmax(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
+    LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);

     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                         int32_t   k,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);

     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_p      (float p, size_t min_keep);

     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
-    LLAMA_API void llama_sample_min_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_min_p      (float p, size_t min_keep);

     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   z,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_tail_free  (float z, size_t min_keep);

     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float p, size_t min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float t);

-    /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
-    LLAMA_API void llama_sample_entropy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates_p,
-                           float   min_temp,
-                           float   max_temp,
-                           float   exponent_val);
-
-    LLAMA_API void llama_sample_temp(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   temp);
+    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float t, float delta, float exponent);

     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1121,36 +1094,62 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                         int32_t   m,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+                         int32_t   n_vocab,
+                        uint32_t   seed,
+                           float   tau,
+                           float   eta,
+                         int32_t   m);

     /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat_v2(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                        uint32_t   seed,
+                           float   tau,
+                           float   eta);

-    /// @details Selects the token with the highest probability.
-    /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
-    LLAMA_API llama_token llama_sample_token_greedy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+            const struct llama_model * model,
+                          const char * grammar_str,
+                          const char * grammar_root);

-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
-    LLAMA_API llama_token llama_sample_token(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+                         int32_t   n_vocab,         // llama_n_vocab()
+                     llama_token   special_eos_id,  // llama_token_eos()
+                     llama_token   linefeed_id,     // llama_token_nl()
+                         int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                           float   penalty_repeat,  // 1.0 = disabled
+                           float   penalty_freq,    // 0.0 = disabled
+                           float   penalty_present, // 0.0 = disabled
+                            bool   penalize_nl,     // consider newlines as a repeatable token
+                            bool   ignore_eos);     // ignore the end-of-sequence token
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+                         int32_t   n_vocab,
+                         int32_t   n_logit_bias,
+          const llama_logit_bias * logit_bias);
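A short sketch of how the two constructors above could be combined in a chain (illustrative values; model is assumed to be a loaded llama_model pointer, nothing here is quoted from the commit):

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_penalties(
            llama_n_vocab(model), llama_token_eos(model), llama_token_nl(model),
            /*penalty_last_n  =*/ 64,    /*penalty_repeat  =*/ 1.1f,
            /*penalty_freq    =*/ 0.0f,  /*penalty_present =*/ 0.0f,
            /*penalize_nl     =*/ false, /*ignore_eos      =*/ false));

    // one bias entry that forbids a token id (42 is only an example)
    const llama_logit_bias banned[] = { { 42, -1e9f } };
    llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(llama_n_vocab(model), 1, banned));

    llama_sampler_chain_add(chain, llama_sampler_init_dist(/*seed =*/ 0));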
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
+    /// @details Sample and accept a token from the idx-th output of the last evaluation
+    //
+    // Shorthand for:
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    auto token = cur_p.data[cur_p.selected].id;
+    //    llama_sampler_accept(smpl, token);
+    //    return token;
+    // Returns the sampled token
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
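When more control is needed than llama_sampler_sample provides, the "init from logits" step in the shorthand above can be written out explicitly. A sketch (illustrative, not from this commit; needs <vector>) that builds cur_p by hand and samples manually:

    static llama_token sample_manually(llama_sampler * smpl, llama_context * ctx, int32_t idx) {
        const int     n_vocab = llama_n_vocab(llama_get_model(ctx));
        const float * logits  = llama_get_logits_ith(ctx, idx);

        std::vector<llama_token_data> cur(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            cur[id] = llama_token_data{ id, logits[id], 0.0f };
        }

        llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected =*/ -1, /*sorted =*/ false };

        llama_sampler_apply(smpl, &cur_p);   // run every sampler in the chain
        const llama_token token = cur_p.data[cur_p.selected].id;
        llama_sampler_accept(smpl, token);   // update stateful samplers (penalties, mirostat, ...)
        return token;
    }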
+
+    // TODO: extend in the future
+    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);

     //
     // Model split
@@ -1166,12 +1165,6 @@ extern "C" {
     // Returns the split_prefix length.
     LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);

-    // Performance information
-    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
-    LLAMA_API void llama_print_timings(struct llama_context * ctx);
-    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);

@@ -1179,65 +1172,41 @@ extern "C" {
     // If this is not called, or NULL is supplied, everything is output on stderr.
     LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);

-    LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+    //
+    // Performance utils
+    //
+    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    //
+
+    struct llama_perf_context_data {
+        double t_start_ms;
+        double t_load_ms;
+        double t_p_eval_ms;
+        double t_eval_ms;
+
+        int32_t n_p_eval;
+        int32_t n_eval;
+    };
+
+    struct llama_perf_sampler_data {
+        double t_sample_ms;
+
+        int32_t n_sample;
+    };
+
+    LLAMA_API struct llama_perf_context_data llama_perf_context      (const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_print(const struct llama_context * ctx);
+    LLAMA_API void                           llama_perf_context_reset(      struct llama_context * ctx);
+
+    // NOTE: the following work only with samplers constructed via llama_sampler_chain_init
+    LLAMA_API struct llama_perf_sampler_data llama_perf_sampler      (const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
+    LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
+
+    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);

 #ifdef __cplusplus
 }
 #endif

-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
-
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules      & rules,
-        const llama_grammar_stack      & stack,
-        const llama_grammar_candidates & candidates);
-
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
-
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
-
-#endif // LLAMA_API_INTERNAL
-
 #endif // LLAMA_H
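For reference, a minimal sketch (not part of this commit) of how the new perf API can stand in for the removed llama_get_timings()/llama_print_timings()/llama_reset_timings() calls. It uses only the declarations added above; ctx_llama and smpl stand in for an existing context and sampler chain, and fprintf assumes <cstdio> is available:

    // Sketch only: read the counters directly ...
    const struct llama_perf_context_data pd = llama_perf_context(ctx_llama);
    fprintf(stderr, "prompt eval: %5d tokens in %8.2f ms\n", pd.n_p_eval, pd.t_p_eval_ms);
    fprintf(stderr, "       eval: %5d tokens in %8.2f ms\n", pd.n_eval,   pd.t_eval_ms);

    // ... or rely on the built-in printers / reset, mirroring the old print/reset calls:
    llama_perf_sampler_print(smpl);
    llama_perf_context_print(ctx_llama);
    llama_perf_context_reset(ctx_llama);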
@@ -314,7 +314,6 @@ int main(int argc, char ** argv) {

     // tune these to your liking
     lcparams.n_ctx      = 2048;
-    lcparams.seed       = 1;
     lcparams.n_threads  = params.n_threads;
     lcparams.flash_attn = params.flash_attn;

@@ -402,6 +401,26 @@ int main(int argc, char ** argv) {

     llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);

+    // init sampler
+    const float top_k = 5;
+    const float top_p = 0.80f;
+    const float temp  = 0.30f;
+
+    const int seed = 0;
+
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    if (temp > 0.0f) {
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
+    } else {
+        llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+    }
+
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
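The chain built here is consumed and released later in this same file. A compressed sketch of that life cycle, using only calls that appear in the hunks below (llama_sampler_sample in the decoding loop, llama_sampler_free at shutdown); the index -1 presumably selects the logits of the last evaluated token:

    // Sketch: each step, sample the next token from the most recent logits ...
    const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
    // ... use id (decode it, append it to the context, etc.) ...

    // ... and at shutdown release the chain:
    llama_sampler_free(smpl);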
@@ -700,54 +719,13 @@ int main(int argc, char ** argv) {

         {
             // out of user input, sample next token
-            const float top_k = 5;
-            const float top_p = 0.80f;
-            const float temp  = 0.30f;
-            const float repeat_penalty = 1.1764f;
-
-            const int repeat_last_n = 256;
-
             if (!path_session.empty() && need_to_save_session) {
                 need_to_save_session = false;
                 llama_state_save_file(ctx_llama, path_session.c_str(), session_tokens.data(), session_tokens.size());
             }

-            llama_token id = 0;
+            const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);

-            {
-                auto logits  = llama_get_logits(ctx_llama);
-                auto n_vocab = llama_n_vocab(model_llama);
-
-                logits[llama_token_eos(model_llama)] = 0;
-
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
-
-                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-                // apply repeat penalty
-                const float nl_logit = logits[llama_token_nl(model_llama)];
-
-                llama_sample_repetition_penalties(ctx_llama, &candidates_p,
-                        embd_inp.data() + std::max(0, n_past - repeat_last_n),
-                        repeat_last_n, repeat_penalty, 0.0, 0.0f);
-
-                logits[llama_token_nl(model_llama)] = nl_logit;
-
-                if (temp <= 0) {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-                } else {
-                    // Temperature sampling
-                    llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                    llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                    llama_sample_temp (ctx_llama, &candidates_p, temp);
-                    id = llama_sample_token(ctx_llama, &candidates_p);
-                }
-            }
-
             if (id != llama_token_eos(model_llama)) {
                 // add it to the context
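The removed block built a llama_token_data_array by hand and applied repetition penalties plus top-k/top-p/temperature before sampling; the single llama_sampler_sample() call now delegates all of that to the chain initialized earlier in the file. A rough mapping of old calls to chain elements (illustrative only, not taken from the commit):

    // old per-step call                    new chain element (added at init time)
    // llama_sample_top_k(...)          ->  llama_sampler_init_top_k(top_k)
    // llama_sample_top_p(...)          ->  llama_sampler_init_top_p(top_p, 1)
    // llama_sample_temp(...)           ->  llama_sampler_init_temp(temp)
    // llama_sample_token(...)          ->  llama_sampler_init_dist(seed) + llama_sampler_sample(...)
    // llama_sample_token_greedy(...)   ->  llama_sampler_init_greedy()   + llama_sampler_sample(...)
    //
    // Note: the old repeat-penalty step (llama_sample_repetition_penalties with
    // repeat_last_n / repeat_penalty) has no counterpart in the new chain as written,
    // so that behavior is dropped by this change.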
@@ -797,8 +775,14 @@ int main(int argc, char ** argv) {
     whisper_print_timings(ctx_wsp);
     whisper_free(ctx_wsp);

-    llama_print_timings(ctx_llama);
+    llama_perf_sampler_print(smpl);
+    llama_perf_context_print(ctx_llama);
+
+    llama_sampler_free(smpl);
+    llama_batch_free(batch);
     llama_free(ctx_llama);

+    llama_backend_free();
+
     return 0;
 }
@@ -5,6 +5,7 @@
 #include "unicode.h"
 #include "unicode-data.h"

+#include <algorithm>
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -177,7 +177,7 @@ static bool ggml_graph_compute_helper(
                   int   n_threads,
         ggml_abort_callback   abort_callback,
                        void * abort_callback_data) {
-    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);

     plan.abort_callback      = abort_callback;
     plan.abort_callback_data = abort_callback_data;
@@ -2894,7 +2894,7 @@ static bool whisper_decode_internal(
         ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
    }

-    logits = gf->nodes[gf->n_nodes - 1];
+    logits = ggml_graph_node(gf, -1);

    if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
        return false;
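A short sketch of the accessor used above; assuming ggml_graph_node() accepts a negative index that counts from the end of the graph (which is what the -1 here relies on), the two forms fetch the same tensor:

    // old: direct field access into the cgraph
    // logits = gf->nodes[gf->n_nodes - 1];
    // new: accessor with from-the-end indexing
    struct ggml_tensor * logits = ggml_graph_node(gf, -1); // last node of the graph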