diff --git a/README.md b/README.md index 2eb3a098..4b548546 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,8 @@ On Linux/MacOS more details are [here](docs/Linux_Osx_Usage.md) ## Options - +* `--model`: the name of the model to be used. The model should be placed in models folder (default: gpt4all-lora-quantized.bin) +* `--seed`: the random seed for reproductibility. If fixed, it is possible to reproduce the outputs exactly (default: random) * `--port`: the port on which to run the server (default: 9600) * `--host`: the host address on which to run the server (default: localhost) * `--temp`: the sampling temperature for the model (default: 0.1) diff --git a/app.py b/app.py index 7df4b835..bb435e3b 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,7 @@ import argparse import json import re +import random import sqlite3 import traceback from datetime import datetime @@ -166,11 +167,12 @@ app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static") class Gpt4AllWebUI: - def __init__(self, chatbot_bindings, _app, db_path="database.db") -> None: + def __init__(self, _app, args) -> None: + self.args = args self.current_discussion = None - self.chatbot_bindings = chatbot_bindings self.app = _app - self.db_path = db_path + self.db_path = args.db_path + self.add_endpoint("/", "", self.index, methods=["GET"]) self.add_endpoint("/export", "export", self.export, methods=["GET"]) self.add_endpoint( @@ -195,21 +197,41 @@ class Gpt4AllWebUI: "/update_message", "update_message", self.update_message, methods=["GET"] ) - conditionning_message = """ + + # Create chatbot + self.chatbot_bindings = self.create_chatbot() + # Chatbot conditionning + self.condition_chatbot() + + + def create_chatbot(self): + return Model( + ggml_model=f"./models/{self.args.model}", + n_ctx=self.args.ctx_size, + seed=self.args.seed, + ) + + def condition_chatbot(self, conditionning_message = """ Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems. Start by welcoming the user then stop sending text. -GPT4All:""" +GPT4All:Welcome! I'm here to assist you with anything you need. What can I do for you today?""" + ): + self.prepare_query(conditionning_message) - chatbot_bindings.generate( + self.chatbot_bindings.generate( conditionning_message, - n_predict=55, new_text_callback=self.new_text_callback, + + n_predict=len(conditionning_message), + temp=self.args.temp, + top_k=self.args.top_k, + top_p=self.args.top_p, + repeat_penalty=self.args.repeat_penalty, + repeat_last_n = self.args.repeat_last_n, + #seed=self.args.seed, n_threads=8, ) - print(f"Bot said:{self.bot_says}") - # Chatbot conditionning - # response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:") - # print(response) + print(f"Bot said:{self.bot_says}") def prepare_query(self, message): self.bot_says = "" @@ -291,12 +313,19 @@ GPT4All:""" ) ) - self.current_message = "User: " + message + "\nGPT4All:" + self.current_message = "\nUser: " + message + "\nGPT4All: " self.prepare_query(self.current_message) - chatbot_model_bindings.generate( + self.chatbot_bindings.generate( self.current_message, - n_predict=55, new_text_callback=self.new_text_callback, + + n_predict=len(self.current_message)+args.n_predict, + temp=self.args.temp, + top_k=self.args.top_k, + top_p=self.args.top_p, + repeat_penalty=self.args.repeat_penalty, + repeat_last_n = self.args.repeat_last_n, + #seed=self.args.seed, n_threads=8, ) @@ -363,6 +392,21 @@ GPT4All:""" discussion_id = data["id"] self.current_discussion = Discussion(discussion_id, self.db_path) messages = self.current_discussion.get_messages() + full_message = "" + for message in messages: + full_message += message['sender'] + ": " + message['content'] + "\n" + + self.chatbot_bindings.generate( + full_message, + new_text_callback=self.new_text_callback, + n_predict=len(messages), + temp=self.args.temp, + top_k=self.args.top_k, + top_p=self.args.top_p, + repeat_penalty= self.args.repeat_penalty, + repeat_last_n = self.args.repeat_last_n, + n_threads=8, + ) return jsonify(messages) def delete_discussion(self): @@ -395,24 +439,30 @@ GPT4All:""" # Get the current timestamp timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - # add a new discussion - # self.chatbot_bindings.close() - # self.chatbot_bindings.open() - + # Create chatbot + self.chatbot_bindings = self.create_chatbot() + # Chatbot conditionning + self.condition_chatbot() # Return a success response return json.dumps({"id": self.current_discussion.discussion_id, "time": timestamp}) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Start the chatbot Flask app.") + parser.add_argument( + "-s", "--seed", type=int, default=0, help="Force using a specific model." + ) + parser.add_argument( + "-m", "--model", type=str, default="gpt4all-lora-quantized.bin", help="Force using a specific model." + ) parser.add_argument( "--temp", type=float, default=0.1, help="Temperature parameter for the model." ) parser.add_argument( "--n_predict", type=int, - default=128, + default=256,#128, help="Number of tokens to predict at each step.", ) parser.add_argument( @@ -433,7 +483,7 @@ if __name__ == "__main__": parser.add_argument( "--ctx_size", type=int, - default=2048, + default=512,#2048, help="Size of the context window for the model.", ) parser.add_argument( @@ -450,26 +500,10 @@ if __name__ == "__main__": "--db_path", type=str, default="database.db", help="Database path" ) parser.set_defaults(debug=False) - args = parser.parse_args() - chatbot_model_bindings = Model( - ggml_model="./models/gpt4all-lora-quantized-ggml.bin", n_ctx=512 - ) - - # Old Code - # GPT4All(decoder_config = { - # 'temp': args.temp, - # 'n_predict':args.n_predict, - # 'top_k':args.top_k, - # 'top_p':args.top_p, - # #'color': True,#"## Instruction", - # 'repeat_penalty': args.repeat_penalty, - # 'repeat_last_n':args.repeat_last_n, - # 'ctx_size': args.ctx_size - # }) check_discussion_db(args.db_path) - bot = Gpt4AllWebUI(chatbot_model_bindings, app, args.db_path) + bot = Gpt4AllWebUI(app, args) if args.debug: app.run(debug=True, host=args.host, port=args.port) diff --git a/run.bat b/run.bat index 52f9c940..78b98076 100644 --- a/run.bat +++ b/run.bat @@ -37,4 +37,4 @@ echo HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH echo on call env/Scripts/activate.bat -python app.py \ No newline at end of file +python app.py %* \ No newline at end of file diff --git a/static/js/chat.js b/static/js/chat.js index c39377bd..b2c86e69 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -527,6 +527,7 @@ const welcome_message = ` - Act as ChefAI an AI that has the ability to create recipes for any occasion. Instruction: Give me a recipe for my next anniversary.
+
Welcome! I'm here to assist you with anything you need. What can I do for you today?
`; //welcome_message = add_collapsible_div("Note:", text, 'hints');