talk-llama : update to latest llama.cpp

This commit is contained in:
Georgi Gerganov
2023-09-15 20:06:31 +03:00
parent 80c1512fd5
commit 1ca4041b86
6 changed files with 5671 additions and 1971 deletions

View File

@ -25,6 +25,20 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return std::string(result.data(), result.size());
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
@ -33,14 +47,14 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi";
@ -235,7 +249,7 @@ int main(int argc, char ** argv) {
// llama init
llama_init_backend();
llama_backend_init(true);
auto lparams = llama_context_default_params();
@ -244,7 +258,9 @@ int main(int argc, char ** argv) {
lparams.seed = 1;
lparams.f16_kv = true;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
struct llama_model * model_llama = llama_load_model_from_file(params.model_llama.c_str(), lparams);
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lparams);
// print some info about the processing
{
@ -267,7 +283,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
@ -278,8 +293,6 @@ int main(int argc, char ** argv) {
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
@ -514,7 +527,7 @@ int main(int argc, char ** argv) {
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
// printf("%s", llama_token_to_piece(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
@ -582,7 +595,7 @@ int main(int argc, char ** argv) {
auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(ctx_llama);
logits[llama_token_eos()] = 0;
logits[llama_token_eos(ctx_llama)] = 0;
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -593,13 +606,13 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// apply repeat penalty
const float nl_logit = logits[llama_token_nl()];
const float nl_logit = logits[llama_token_nl(ctx_llama)];
llama_sample_repetition_penalty(ctx_llama, &candidates_p,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, repeat_penalty);
logits[llama_token_nl()] = nl_logit;
logits[llama_token_nl(ctx_llama)] = nl_logit;
if (temp <= 0) {
// Greedy sampling
@ -613,22 +626,22 @@ int main(int argc, char ** argv) {
}
}
if (id != llama_token_eos()) {
if (id != llama_token_eos(ctx_llama)) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_str(ctx_llama, id);
text_to_speak += llama_token_to_piece(ctx_llama, id);
printf("%s", llama_token_to_str(ctx_llama, id));
printf("%s", llama_token_to_piece(ctx_llama, id).c_str());
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
last_output += llama_token_to_piece(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_str(ctx_llama, embd[0]);
last_output += llama_token_to_piece(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
@ -655,8 +668,6 @@ int main(int argc, char ** argv) {
}
audio.clear();
++n_iter;
}
}
}