mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-30 16:13:57 +00:00
talk-llama : fix session prompt load (#854)
This commit is contained in:
parent
b806420873
commit
0bf680fea2
@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
|
||||||
|
|
||||||
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
|
||||||
|
|
||||||
printf("\n");
|
|
||||||
printf("%s : initializing - please wait ...\n", __func__);
|
|
||||||
|
|
||||||
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.verbose_prompt) {
|
|
||||||
fprintf(stdout, "\n");
|
|
||||||
fprintf(stdout, "%s", prompt_llama.c_str());
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
|
|
||||||
// init session
|
// init session
|
||||||
std::string path_session = params.path_session;
|
std::string path_session = params.path_session;
|
||||||
std::vector<llama_token> session_tokens;
|
std::vector<llama_token> session_tokens;
|
||||||
|
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
|
||||||
|
|
||||||
if (!path_session.empty()) {
|
if (!path_session.empty()) {
|
||||||
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
|
||||||
@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
session_tokens.resize(n_token_count_out);
|
session_tokens.resize(n_token_count_out);
|
||||||
|
for (size_t i = 0; i < session_tokens.size(); i++) {
|
||||||
|
embd_inp[i] = session_tokens[i];
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
|
||||||
} else {
|
} else {
|
||||||
@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// evaluate the initial prompt
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
printf("%s : initializing - please wait ...\n", __func__);
|
||||||
|
|
||||||
|
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.verbose_prompt) {
|
||||||
|
fprintf(stdout, "\n");
|
||||||
|
fprintf(stdout, "%s", prompt_llama.c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
// debug message about similarity of saved session, if applicable
|
// debug message about similarity of saved session, if applicable
|
||||||
size_t n_matching_session_tokens = 0;
|
size_t n_matching_session_tokens = 0;
|
||||||
if (session_tokens.size()) {
|
if (session_tokens.size()) {
|
||||||
@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
int n_past = n_keep;
|
int n_past = n_keep;
|
||||||
int n_prev = 64; // TODO arg
|
int n_prev = 64; // TODO arg
|
||||||
int n_session_consumed = 0;
|
int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
|
||||||
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
|
||||||
@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
embd = ::llama_tokenize(ctx_llama, text_heard, false);
|
||||||
|
|
||||||
|
// Append the new input tokens to the session_tokens vector
|
||||||
|
if (!path_session.empty()) {
|
||||||
|
session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
|
||||||
|
}
|
||||||
|
|
||||||
// text inference
|
// text inference
|
||||||
bool done = false;
|
bool done = false;
|
||||||
std::string text_to_speak;
|
std::string text_to_speak;
|
||||||
@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (embd.size() > 0 && !path_session.empty()) {
|
||||||
|
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||||
|
n_session_consumed = session_tokens.size();
|
||||||
|
}
|
||||||
|
|
||||||
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
|
|
||||||
|
|
||||||
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
|
||||||
n_past += embd.size();
|
n_past += embd.size();
|
||||||
if (embd.size() > 0 && !path_session.empty()) {
|
|
||||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
|
||||||
n_session_consumed = session_tokens.size();
|
|
||||||
}
|
|
||||||
embd.clear();
|
embd.clear();
|
||||||
|
|
||||||
if (done) break;
|
if (done) break;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user