From cc24c28188c3a46f9e035fc6137e79d126defe3e Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 12 Apr 2023 22:36:03 +0200 Subject: [PATCH] configuration is now percistant --- app.py | 129 ++++++++++++++++++++----------------------- config.py | 22 ++++++++ configs/default.yaml | 13 +++++ requirements.txt | 1 + templates/chat.html | 17 +++--- 5 files changed, 104 insertions(+), 78 deletions(-) create mode 100644 config.py create mode 100644 configs/default.yaml diff --git a/app.py b/app.py index 8ce5326e..e79a44f6 100644 --- a/app.py +++ b/app.py @@ -32,14 +32,15 @@ from pathlib import Path import gc app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static") import time +from config import load_config class Gpt4AllWebUI: - def __init__(self, _app, args) -> None: - self.args = args + def __init__(self, _app, config:dict) -> None: + self.config = config self.current_discussion = None self.app = _app - self.db_path = args.db_path + self.db_path = config["db_path"] self.db = DiscussionsDB(self.db_path) # If the database is empty, populate it with tables self.db.populate() @@ -92,7 +93,7 @@ class Gpt4AllWebUI: ) self.add_endpoint( - "/get_args", "get_args", self.get_args, methods=["GET"] + "/get_config", "get_config", self.get_config, methods=["GET"] ) self.add_endpoint( @@ -128,9 +129,9 @@ class Gpt4AllWebUI: def create_chatbot(self): return Model( - ggml_model=f"./models/{self.args.model}", - n_ctx=self.args.ctx_size, - seed=self.args.seed, + ggml_model=f"./models/{self.config['model']}", + n_ctx=self.config['ctx_size'], + seed=self.config['seed'], ) def condition_chatbot(self, conditionning_message = """ @@ -151,23 +152,6 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo self.full_message_list.append(conditionning_message) - # self.prepare_query(conditionning_message) - # self.chatbot_bindings.generate( - # conditionning_message, - # new_text_callback=self.new_text_callback, - - # n_predict=len(conditionning_message), - # temp=self.args.temp, - # top_k=self.args.top_k, - # top_p=self.args.top_p, - # repeat_penalty=self.args.repeat_penalty, - # repeat_last_n = self.args.repeat_last_n, - # seed=self.args.seed, - # n_threads=8 - # ) - # print(f"Bot said:{self.bot_says}") - - def prepare_query(self): self.bot_says = "" self.full_text = "" @@ -229,13 +213,13 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo self.chatbot_bindings.generate( self.prompt_message,#self.full_message,#self.current_message, new_text_callback=self.new_text_callback, - n_predict=len(self.current_message)+self.args.n_predict, - temp=self.args.temp, - top_k=self.args.top_k, - top_p=self.args.top_p, - repeat_penalty=self.args.repeat_penalty, - repeat_last_n = self.args.repeat_last_n, - #seed=self.args.seed, + n_predict=len(self.current_message)+self.config['n_predict'], + temp=self.config['temp'], + top_k=self.config['top_k'], + top_p=self.config['top_p'], + repeat_penalty=self.config['repeat_penalty'], + repeat_last_n = self.config['repeat_last_n'], + #seed=self.config['seed'], n_threads=8 ) self.generating=False @@ -327,11 +311,11 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo self.prompt_message,#full_message, new_text_callback=self.new_text_callback, n_predict=0,#len(full_message), - temp=self.args.temp, - top_k=self.args.top_k, - top_p=self.args.top_p, - repeat_penalty= self.args.repeat_penalty, - repeat_last_n = self.args.repeat_last_n, + temp=self.config['temp'], + top_k=self.config['top_k'], + top_p=self.config['top_p'], + repeat_penalty= self.config['repeat_penalty'], + repeat_last_n = self.config['repeat_last_n'], n_threads=8 ) @@ -390,33 +374,33 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo def update_model_params(self): data = request.get_json() model = str(data["model"]) - if self.args.model != model: + if self.config['model'] != model: print("New model selected") - self.args.model = model + self.config['model'] = model self.prepare_a_new_chatbot() - self.args.n_predict = int(data["nPredict"]) - self.args.seed = int(data["seed"]) + self.config['n_predict'] = int(data["nPredict"]) + self.config['seed'] = int(data["seed"]) - self.args.temp = float(data["temp"]) - self.args.top_k = int(data["topK"]) - self.args.top_p = float(data["topP"]) - self.args.repeat_penalty = int(data["repeatPenalty"]) - self.args.repeat_last_n = int(data["repeatLastN"]) + self.config['temp'] = float(data["temp"]) + self.config['top_k'] = int(data["topK"]) + self.config['top_p'] = float(data["topP"]) + self.config['repeat_penalty'] = float(data["repeatPenalty"]) + self.config['repeat_last_n'] = int(data["repeatLastN"]) print("Parameters changed to:") - print(f"\tTemperature:{self.args.temp}") - print(f"\tNPredict:{self.args.n_predict}") - print(f"\tSeed:{self.args.seed}") - print(f"\top_k:{self.args.top_k}") - print(f"\top_p:{self.args.top_p}") - print(f"\trepeat_penalty:{self.args.repeat_penalty}") - print(f"\trepeat_last_n:{self.args.repeat_last_n}") + print(f"\tTemperature:{self.config['temp']}") + print(f"\tNPredict:{self.config['n_predict']}") + print(f"\tSeed:{self.config['seed']}") + print(f"\top_k:{self.config['top_k']}") + print(f"\top_p:{self.config['top_p']}") + print(f"\trepeat_penalty:{self.config['repeat_penalty']}") + print(f"\trepeat_last_n:{self.config['repeat_last_n']}") return jsonify({"status":"ok"}) - def get_args(self): - return jsonify(self.args) + def get_config(self): + return jsonify(self.config) def help(self): return render_template("help.html") @@ -432,40 +416,40 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo if __name__ == "__main__": parser = argparse.ArgumentParser(description="Start the chatbot Flask app.") parser.add_argument( - "-s", "--seed", type=int, default=0, help="Force using a specific model." + "-s", "--seed", type=int, default=None, help="Force using a specific model." ) parser.add_argument( - "-m", "--model", type=str, default="gpt4all-lora-quantized-ggml.bin", help="Force using a specific model." + "-m", "--model", type=str, default=None, help="Force using a specific model." ) parser.add_argument( - "--temp", type=float, default=0.1, help="Temperature parameter for the model." + "--temp", type=float, default=None, help="Temperature parameter for the model." ) parser.add_argument( "--n_predict", type=int, - default=256, + default=None, help="Number of tokens to predict at each step.", ) parser.add_argument( - "--top_k", type=int, default=40, help="Value for the top-k sampling." + "--top_k", type=int, default=None, help="Value for the top-k sampling." ) parser.add_argument( - "--top_p", type=float, default=0.95, help="Value for the top-p sampling." + "--top_p", type=float, default=None, help="Value for the top-p sampling." ) parser.add_argument( - "--repeat_penalty", type=float, default=1.3, help="Penalty for repeated tokens." + "--repeat_penalty", type=float, default=None, help="Penalty for repeated tokens." ) parser.add_argument( "--repeat_last_n", type=int, - default=64, + default=None, help="Number of previous tokens to consider for the repeat penalty.", ) parser.add_argument( "--ctx_size", type=int, - default=512,#2048, + default=None,#2048, help="Size of the context window for the model.", ) parser.add_argument( @@ -477,19 +461,26 @@ if __name__ == "__main__": parser.add_argument( "--host", type=str, default="localhost", help="the hostname to listen on" ) - parser.add_argument("--port", type=int, default=9600, help="the port to listen on") + parser.add_argument("--port", type=int, default=None, help="the port to listen on") parser.add_argument( - "--db_path", type=str, default="database.db", help="Database path" + "--db_path", type=str, default=None, help="Database path" ) parser.set_defaults(debug=False) args = parser.parse_args() + config_file_path = "configs/default.yaml" + config = load_config(config_file_path) + + # Override values in config with command-line arguments + for arg_name, arg_value in vars(args).items(): + if arg_value is not None: + config[arg_name] = arg_value executor = ThreadPoolExecutor(max_workers=2) app.config['executor'] = executor - bot = Gpt4AllWebUI(app, args) + bot = Gpt4AllWebUI(app, config) - if args.debug: - app.run(debug=True, host=args.host, port=args.port) + if config["debug"]: + app.run(debug=True, host=config["host"], port=config["port"]) else: - app.run(host=args.host, port=args.port) + app.run(host=config["host"], port=config["port"]) diff --git a/config.py b/config.py new file mode 100644 index 00000000..8f211d8e --- /dev/null +++ b/config.py @@ -0,0 +1,22 @@ +###### +# Project : GPT4ALL-UI +# File : config.py +# Author : ParisNeo with the help of the community +# Supported by Nomic-AI +# Licence : Apache 2.0 +# Description : +# A front end Flask application for llamacpp models. +# The official GPT4All Web ui +# Made by the community for the community +###### +import yaml + +def load_config(file_path): + with open(file_path, 'r') as stream: + config = yaml.safe_load(stream) + return config + + +def save_config(config, filepath): + with open(filepath, "w") as f: + yaml.dump(config, f) \ No newline at end of file diff --git a/configs/default.yaml b/configs/default.yaml new file mode 100644 index 00000000..9be3eae0 --- /dev/null +++ b/configs/default.yaml @@ -0,0 +1,13 @@ +seed: 0 +model: "gpt4all-lora-quantized-ggml.bin" +temp: 0.1 +n_predict: 256 +top_k: 40 +top_p: 0.95 +repeat_penalty: 1.3 +repeat_last_n: 64 +ctx_size: 512 +debug: false +host: "localhost" +port: 9600 +db_path: "database.db" diff --git a/requirements.txt b/requirements.txt index a3275191..8a16e560 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ flask nomic pytest pyllamacpp +pyyaml \ No newline at end of file diff --git a/templates/chat.html b/templates/chat.html index e0674dfd..c0aac2fc 100644 --- a/templates/chat.html +++ b/templates/chat.html @@ -69,7 +69,7 @@