talk.wasm : polishing + adding many AI personalities

This commit is contained in:
Georgi Gerganov
2022-11-22 20:10:20 +02:00
parent 385236d1d3
commit 9aea96f774
4 changed files with 383 additions and 48 deletions

View File

@ -988,7 +988,7 @@ std::atomic<bool> g_running(false);
bool g_force_speak = false;
std::string g_text_to_speak = "";
std::string g_status = "idle";
std::string g_status = "";
std::string g_status_forced = "";
std::string gpt2_gen_text(const std::string & prompt) {
@ -997,7 +997,7 @@ std::string gpt2_gen_text(const std::string & prompt) {
std::vector<float> embd_w;
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, g_gpt2.prompt_base + prompt);
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, prompt);
g_gpt2.n_predict = std::min(g_gpt2.n_predict, g_gpt2.model.hparams.n_ctx - (int) embd_inp.size());
@ -1088,6 +1088,8 @@ void talk_main(size_t index) {
printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
}
printf("talk: using %d threads\n", N_THREAD);
std::vector<float> pcmf32;
auto & ctx = g_contexts[index];
@ -1214,9 +1216,15 @@ void talk_main(size_t index) {
printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
std::string text_to_speak;
std::string prompt_base;
{
std::lock_guard<std::mutex> lock(g_mutex);
prompt_base = g_gpt2.prompt_base;
}
if (tokens.size() > 0) {
text_to_speak = gpt2_gen_text(text_heard + "\n");
text_to_speak = gpt2_gen_text(prompt_base + text_heard + "\n");
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
@ -1224,36 +1232,36 @@ void talk_main(size_t index) {
// remove first 2 lines of base prompt
{
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
prompt_base = prompt_base.substr(pos + 1);
}
}
{
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
prompt_base = prompt_base.substr(pos + 1);
}
}
g_gpt2.prompt_base += text_heard + "\n" + text_to_speak + "\n";
prompt_base += text_heard + "\n" + text_to_speak + "\n";
} else {
text_to_speak = gpt2_gen_text("");
text_to_speak = gpt2_gen_text(prompt_base);
text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
std::lock_guard<std::mutex> lock(g_mutex);
const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
const size_t pos = prompt_base.find_first_of("\n");
if (pos != std::string::npos) {
g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
prompt_base = prompt_base.substr(pos + 1);
}
g_gpt2.prompt_base += text_to_speak + "\n";
prompt_base += text_to_speak + "\n";
}
printf("gpt-2: %s\n", text_to_speak.c_str());
//printf("========================\n");
//printf("gpt-2: prompt_base:\n'%s'\n", g_gpt2.prompt_base.c_str());
//printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
//printf("========================\n");
{
@ -1261,6 +1269,7 @@ void talk_main(size_t index) {
t_last = std::chrono::high_resolution_clock::now();
g_text_to_speak = text_to_speak;
g_pcmf32.clear();
g_gpt2.prompt_base = prompt_base;
}
talk_set_status("speaking ...");
@ -1376,4 +1385,11 @@ EMSCRIPTEN_BINDINGS(talk) {
g_status_forced = status;
}
}));
emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
{
std::lock_guard<std::mutex> lock(g_mutex);
g_gpt2.prompt_base = prompt;
}
}));
}