diff --git a/README.md b/README.md
index 2eb3a098..4b548546 100644
--- a/README.md
+++ b/README.md
@@ -93,7 +93,8 @@ On Linux/MacOS more details are [here](docs/Linux_Osx_Usage.md)
## Options
-
+* `--model`: the name of the model to be used. The model should be placed in the `models` folder (default: gpt4all-lora-quantized.bin)
+* `--seed`: the random seed for reproducibility. If fixed, it is possible to reproduce the outputs exactly (default: random)
* `--port`: the port on which to run the server (default: 9600)
* `--host`: the host address on which to run the server (default: localhost)
* `--temp`: the sampling temperature for the model (default: 0.1)
diff --git a/app.py b/app.py
index 7df4b835..bb435e3b 100644
--- a/app.py
+++ b/app.py
@@ -1,6 +1,7 @@
import argparse
import json
import re
+import random
import sqlite3
import traceback
from datetime import datetime
@@ -166,11 +167,12 @@ app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
class Gpt4AllWebUI:
- def __init__(self, chatbot_bindings, _app, db_path="database.db") -> None:
+ def __init__(self, _app, args) -> None:
+ self.args = args
self.current_discussion = None
- self.chatbot_bindings = chatbot_bindings
self.app = _app
- self.db_path = db_path
+ self.db_path = args.db_path
+
self.add_endpoint("/", "", self.index, methods=["GET"])
self.add_endpoint("/export", "export", self.export, methods=["GET"])
self.add_endpoint(
@@ -195,21 +197,41 @@ class Gpt4AllWebUI:
"/update_message", "update_message", self.update_message, methods=["GET"]
)
- conditionning_message = """
+
+ # Create chatbot
+ self.chatbot_bindings = self.create_chatbot()
+ # Chatbot conditioning
+ self.condition_chatbot()
+
+
+ def create_chatbot(self):
+ return Model(
+ ggml_model=f"./models/{self.args.model}",
+ n_ctx=self.args.ctx_size,
+ seed=self.args.seed,
+ )
+
+ def condition_chatbot(self, conditionning_message = """
Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems.
Start by welcoming the user then stop sending text.
-GPT4All:"""
+GPT4All:Welcome! I'm here to assist you with anything you need. What can I do for you today?"""
+ ):
+
self.prepare_query(conditionning_message)
- chatbot_bindings.generate(
+ self.chatbot_bindings.generate(
conditionning_message,
- n_predict=55,
new_text_callback=self.new_text_callback,
+
+ n_predict=len(conditionning_message),
+ temp=self.args.temp,
+ top_k=self.args.top_k,
+ top_p=self.args.top_p,
+ repeat_penalty=self.args.repeat_penalty,
+ repeat_last_n = self.args.repeat_last_n,
+ #seed=self.args.seed,
n_threads=8,
)
- print(f"Bot said:{self.bot_says}")
- # Chatbot conditionning
- # response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:")
- # print(response)
+ print(f"Bot said:{self.bot_says}")
def prepare_query(self, message):
self.bot_says = ""
@@ -291,12 +313,19 @@ GPT4All:"""
)
)
- self.current_message = "User: " + message + "\nGPT4All:"
+ self.current_message = "\nUser: " + message + "\nGPT4All: "
self.prepare_query(self.current_message)
- chatbot_model_bindings.generate(
+ self.chatbot_bindings.generate(
self.current_message,
- n_predict=55,
new_text_callback=self.new_text_callback,
+
+ n_predict=len(self.current_message)+args.n_predict,
+ temp=self.args.temp,
+ top_k=self.args.top_k,
+ top_p=self.args.top_p,
+ repeat_penalty=self.args.repeat_penalty,
+ repeat_last_n = self.args.repeat_last_n,
+ #seed=self.args.seed,
n_threads=8,
)
@@ -363,6 +392,21 @@ GPT4All:"""
discussion_id = data["id"]
self.current_discussion = Discussion(discussion_id, self.db_path)
messages = self.current_discussion.get_messages()
+ full_message = ""
+ for message in messages:
+ full_message += message['sender'] + ": " + message['content'] + "\n"
+
+ self.chatbot_bindings.generate(
+ full_message,
+ new_text_callback=self.new_text_callback,
+ n_predict=len(messages),
+ temp=self.args.temp,
+ top_k=self.args.top_k,
+ top_p=self.args.top_p,
+ repeat_penalty= self.args.repeat_penalty,
+ repeat_last_n = self.args.repeat_last_n,
+ n_threads=8,
+ )
return jsonify(messages)
def delete_discussion(self):
@@ -395,24 +439,30 @@ GPT4All:"""
# Get the current timestamp
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- # add a new discussion
- # self.chatbot_bindings.close()
- # self.chatbot_bindings.open()
-
+ # Create chatbot
+ self.chatbot_bindings = self.create_chatbot()
+ # Chatbot conditioning
+ self.condition_chatbot()
# Return a success response
return json.dumps({"id": self.current_discussion.discussion_id, "time": timestamp})
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Start the chatbot Flask app.")
+ parser.add_argument(
+ "-s", "--seed", type=int, default=0, help="Force using a specific model."
+ )
+ parser.add_argument(
+ "-m", "--model", type=str, default="gpt4all-lora-quantized.bin", help="Force using a specific model."
+ )
parser.add_argument(
"--temp", type=float, default=0.1, help="Temperature parameter for the model."
)
parser.add_argument(
"--n_predict",
type=int,
- default=128,
+ default=256,#128,
help="Number of tokens to predict at each step.",
)
parser.add_argument(
@@ -433,7 +483,7 @@ if __name__ == "__main__":
parser.add_argument(
"--ctx_size",
type=int,
- default=2048,
+ default=512,#2048,
help="Size of the context window for the model.",
)
parser.add_argument(
@@ -450,26 +500,10 @@ if __name__ == "__main__":
"--db_path", type=str, default="database.db", help="Database path"
)
parser.set_defaults(debug=False)
-
args = parser.parse_args()
- chatbot_model_bindings = Model(
- ggml_model="./models/gpt4all-lora-quantized-ggml.bin", n_ctx=512
- )
-
- # Old Code
- # GPT4All(decoder_config = {
- # 'temp': args.temp,
- # 'n_predict':args.n_predict,
- # 'top_k':args.top_k,
- # 'top_p':args.top_p,
- # #'color': True,#"## Instruction",
- # 'repeat_penalty': args.repeat_penalty,
- # 'repeat_last_n':args.repeat_last_n,
- # 'ctx_size': args.ctx_size
- # })
check_discussion_db(args.db_path)
- bot = Gpt4AllWebUI(chatbot_model_bindings, app, args.db_path)
+ bot = Gpt4AllWebUI(app, args)
if args.debug:
app.run(debug=True, host=args.host, port=args.port)
diff --git a/run.bat b/run.bat
index 52f9c940..78b98076 100644
--- a/run.bat
+++ b/run.bat
@@ -37,4 +37,4 @@ echo HHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHH
echo on
call env/Scripts/activate.bat
-python app.py
\ No newline at end of file
+python app.py %*
\ No newline at end of file
diff --git a/static/js/chat.js b/static/js/chat.js
index c39377bd..b2c86e69 100644
--- a/static/js/chat.js
+++ b/static/js/chat.js
@@ -527,6 +527,7 @@ const welcome_message = `
- Act as ChefAI an AI that has the ability to create recipes for any occasion. Instruction: Give me a recipe for my next anniversary.
+