mirror of
https://github.com/ParisNeo/lollms-webui.git
synced 2024-12-19 04:17:52 +00:00
Merge pull request #30 from NJannasch/feature/python-cleanup
Update python code (isort, black, pylint) and some manual tuning
This commit is contained in:
commit
b7c2c225d3
391
app.py
391
app.py
@ -1,129 +1,150 @@
|
||||
from flask import Flask, jsonify, request, render_template, Response, stream_with_context
|
||||
from pyllamacpp.model import Model
|
||||
import argparse
|
||||
import threading
|
||||
from io import StringIO
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from flask import (
|
||||
Flask,
|
||||
Response,
|
||||
jsonify,
|
||||
render_template,
|
||||
request,
|
||||
stream_with_context,
|
||||
)
|
||||
from pyllamacpp.model import Model
|
||||
|
||||
import select
|
||||
|
||||
#=================================== Database ==================================================================
|
||||
# =================================== Database ==================================================================
|
||||
class Discussion:
|
||||
def __init__(self, discussion_id, db_path='database.db'):
|
||||
def __init__(self, discussion_id, db_path="database.db"):
|
||||
self.discussion_id = discussion_id
|
||||
self.db_path = db_path
|
||||
|
||||
@staticmethod
|
||||
def create_discussion(db_path='database.db', title='untitled'):
|
||||
def create_discussion(db_path="database.db", title="untitled"):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute("INSERT INTO discussion (title) VALUES (?)", (title,))
|
||||
discussion_id = cur.lastrowid
|
||||
conn.commit()
|
||||
return Discussion(discussion_id, db_path)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_discussion(db_path='database.db', id=0):
|
||||
return Discussion(id, db_path)
|
||||
def get_discussion(db_path="database.db", discussion_id=0):
|
||||
return Discussion(discussion_id, db_path)
|
||||
|
||||
def add_message(self, sender, content):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)',
|
||||
(sender, content, self.discussion_id))
|
||||
cur.execute(
|
||||
"INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)",
|
||||
(sender, content, self.discussion_id),
|
||||
)
|
||||
message_id = cur.lastrowid
|
||||
conn.commit()
|
||||
return message_id
|
||||
|
||||
@staticmethod
|
||||
def get_discussions(db_path):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('SELECT * FROM discussion')
|
||||
cursor.execute("SELECT * FROM discussion")
|
||||
rows = cursor.fetchall()
|
||||
return [{'id': row[0], 'title': row[1]} for row in rows]
|
||||
return [{"id": row[0], "title": row[1]} for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def rename(db_path, discussion_id, title):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('UPDATE discussion SET title=? WHERE id=?', (title, discussion_id))
|
||||
cursor.execute(
|
||||
"UPDATE discussion SET title=? WHERE id=?", (title, discussion_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def delete_discussion(self):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('DELETE FROM message WHERE discussion_id=?', (self.discussion_id,))
|
||||
cur.execute('DELETE FROM discussion WHERE id=?', (self.discussion_id,))
|
||||
cur.execute(
|
||||
"DELETE FROM message WHERE discussion_id=?", (self.discussion_id,)
|
||||
)
|
||||
cur.execute("DELETE FROM discussion WHERE id=?", (self.discussion_id,))
|
||||
conn.commit()
|
||||
|
||||
def get_messages(self):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('SELECT * FROM message WHERE discussion_id=?', (self.discussion_id,))
|
||||
cur.execute(
|
||||
"SELECT * FROM message WHERE discussion_id=?", (self.discussion_id,)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [{'sender': row[1], 'content': row[2], 'id':row[0]} for row in rows]
|
||||
|
||||
|
||||
return [{"sender": row[1], "content": row[2], "id": row[0]} for row in rows]
|
||||
|
||||
def update_message(self, message_id, new_content):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('UPDATE message SET content = ? WHERE id = ?', (new_content, message_id))
|
||||
cur.execute(
|
||||
"UPDATE message SET content = ? WHERE id = ?", (new_content, message_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def remove_discussion(self):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.cursor().execute('DELETE FROM discussion WHERE id=?', (self.discussion_id,))
|
||||
conn.cursor().execute(
|
||||
"DELETE FROM discussion WHERE id=?", (self.discussion_id,)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def last_discussion_has_messages(db_path='database.db'):
|
||||
|
||||
def last_discussion_has_messages(db_path="database.db"):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
c = conn.cursor()
|
||||
c.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
|
||||
last_message = c.fetchone()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
|
||||
last_message = cursor.fetchone()
|
||||
return last_message is not None
|
||||
|
||||
def export_to_json(db_path='database.db'):
|
||||
|
||||
def export_to_json(db_path="database.db"):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('SELECT * FROM discussion')
|
||||
cur.execute("SELECT * FROM discussion")
|
||||
discussions = []
|
||||
for row in cur.fetchall():
|
||||
discussion_id = row[0]
|
||||
discussion = {'id': discussion_id, 'messages': []}
|
||||
cur.execute('SELECT * FROM message WHERE discussion_id=?', (discussion_id,))
|
||||
discussion = {"id": discussion_id, "messages": []}
|
||||
cur.execute("SELECT * FROM message WHERE discussion_id=?", (discussion_id,))
|
||||
for message_row in cur.fetchall():
|
||||
discussion['messages'].append({'sender': message_row[1], 'content': message_row[2]})
|
||||
discussion["messages"].append(
|
||||
{"sender": message_row[1], "content": message_row[2]}
|
||||
)
|
||||
discussions.append(discussion)
|
||||
return discussions
|
||||
|
||||
def remove_discussions(db_path='database.db'):
|
||||
|
||||
def remove_discussions(db_path="database.db"):
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('DELETE FROM message')
|
||||
cur.execute('DELETE FROM discussion')
|
||||
cur.execute("DELETE FROM message")
|
||||
cur.execute("DELETE FROM discussion")
|
||||
conn.commit()
|
||||
|
||||
|
||||
# create database schema
|
||||
def check_discussion_db(db_path):
|
||||
print("Checking discussions database...")
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cur = conn.cursor()
|
||||
cur.execute('''
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS discussion (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT
|
||||
)
|
||||
''')
|
||||
cur.execute('''
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS message (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
sender TEXT NOT NULL,
|
||||
@ -131,58 +152,78 @@ def check_discussion_db(db_path):
|
||||
discussion_id INTEGER NOT NULL,
|
||||
FOREIGN KEY (discussion_id) REFERENCES discussion(id)
|
||||
)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
print("Ok")
|
||||
|
||||
|
||||
# ========================================================================================================================
|
||||
|
||||
|
||||
app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
|
||||
|
||||
app = Flask("GPT4All-WebUI", static_url_path='/static', static_folder='static')
|
||||
class Gpt4AllWebUI():
|
||||
def __init__(self, chatbot_bindings, app, db_path='database.db') -> None:
|
||||
|
||||
class Gpt4AllWebUI:
|
||||
def __init__(self, chatbot_bindings, _app, db_path="database.db") -> None:
|
||||
self.current_discussion = None
|
||||
self.chatbot_bindings = chatbot_bindings
|
||||
self.app=app
|
||||
self.db_path= db_path
|
||||
self.add_endpoint('/', '', self.index, methods=['GET'])
|
||||
self.add_endpoint('/export', 'export', self.export, methods=['GET'])
|
||||
self.add_endpoint('/new_discussion', 'new_discussion', self.new_discussion, methods=['GET'])
|
||||
self.add_endpoint('/bot', 'bot', self.bot, methods=['POST'])
|
||||
self.add_endpoint('/discussions', 'discussions', self.discussions, methods=['GET'])
|
||||
self.add_endpoint('/rename', 'rename', self.rename, methods=['POST'])
|
||||
self.add_endpoint('/get_messages', 'get_messages', self.get_messages, methods=['POST'])
|
||||
self.add_endpoint('/delete_discussion', 'delete_discussion', self.delete_discussion, methods=['POST'])
|
||||
self.app = _app
|
||||
self.db_path = db_path
|
||||
self.add_endpoint("/", "", self.index, methods=["GET"])
|
||||
self.add_endpoint("/export", "export", self.export, methods=["GET"])
|
||||
self.add_endpoint(
|
||||
"/new_discussion", "new_discussion", self.new_discussion, methods=["GET"]
|
||||
)
|
||||
self.add_endpoint("/bot", "bot", self.bot, methods=["POST"])
|
||||
self.add_endpoint(
|
||||
"/discussions", "discussions", self.discussions, methods=["GET"]
|
||||
)
|
||||
self.add_endpoint("/rename", "rename", self.rename, methods=["POST"])
|
||||
self.add_endpoint(
|
||||
"/get_messages", "get_messages", self.get_messages, methods=["POST"]
|
||||
)
|
||||
self.add_endpoint(
|
||||
"/delete_discussion",
|
||||
"delete_discussion",
|
||||
self.delete_discussion,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
self.add_endpoint('/update_message', 'update_message', self.update_message, methods=['GET'])
|
||||
|
||||
conditionning_message="""
|
||||
self.add_endpoint(
|
||||
"/update_message", "update_message", self.update_message, methods=["GET"]
|
||||
)
|
||||
|
||||
conditionning_message = """
|
||||
Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems.
|
||||
Start by welcoming the user then stop sending text.
|
||||
GPT4All:"""
|
||||
self.prepare_query(conditionning_message)
|
||||
chatbot_bindings.generate(conditionning_message, n_predict=55, new_text_callback=self.new_text_callback, n_threads=8)
|
||||
print(f"Bot said:{self.bot_says}")
|
||||
chatbot_bindings.generate(
|
||||
conditionning_message,
|
||||
n_predict=55,
|
||||
new_text_callback=self.new_text_callback,
|
||||
n_threads=8,
|
||||
)
|
||||
print(f"Bot said:{self.bot_says}")
|
||||
# Chatbot conditionning
|
||||
# response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:")
|
||||
# print(response)
|
||||
|
||||
def prepare_query(self, message):
|
||||
self.bot_says=''
|
||||
self.full_text=''
|
||||
self.is_bot_text_started=False
|
||||
self.bot_says = ""
|
||||
self.full_text = ""
|
||||
self.is_bot_text_started = False
|
||||
self.current_message = message
|
||||
|
||||
|
||||
def new_text_callback(self, text: str):
|
||||
print(text, end="")
|
||||
self.full_text += text
|
||||
if self.is_bot_text_started:
|
||||
self.bot_says += text
|
||||
if self.current_message in self.full_text:
|
||||
self.is_bot_text_started=True
|
||||
self.is_bot_text_started = True
|
||||
|
||||
def new_text_callback_with_yield(self, text: str):
|
||||
"""
|
||||
@ -193,14 +234,24 @@ GPT4All:"""
|
||||
if self.is_bot_text_started:
|
||||
self.bot_says += text
|
||||
if self.current_message in self.full_text:
|
||||
self.is_bot_text_started=True
|
||||
self.is_bot_text_started = True
|
||||
yield text
|
||||
|
||||
def add_endpoint(self, endpoint=None, endpoint_name=None, handler=None, methods=['GET'], *args, **kwargs):
|
||||
self.app.add_url_rule(endpoint, endpoint_name, handler, methods=methods, *args, **kwargs)
|
||||
def add_endpoint(
|
||||
self,
|
||||
endpoint=None,
|
||||
endpoint_name=None,
|
||||
handler=None,
|
||||
methods=["GET"],
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.app.add_url_rule(
|
||||
endpoint, endpoint_name, handler, methods=methods, *args, **kwargs
|
||||
)
|
||||
|
||||
def index(self):
|
||||
return render_template('chat.html')
|
||||
return render_template("chat.html")
|
||||
|
||||
def format_message(self, message):
|
||||
# Look for a code block within the message
|
||||
@ -220,120 +271,192 @@ GPT4All:"""
|
||||
|
||||
@stream_with_context
|
||||
def parse_to_prompt_stream(self, message, message_id):
|
||||
bot_says = ''
|
||||
self.stop=False
|
||||
bot_says = ""
|
||||
self.stop = False
|
||||
|
||||
# send the message to the bot
|
||||
print(f"Received message : {message}")
|
||||
# First we need to send the new message ID to the client
|
||||
response_id = self.current_discussion.add_message("GPT4All",'') # first the content is empty, but we'll fill it at the end
|
||||
yield(json.dumps({'type':'input_message_infos','message':message, 'id':message_id, 'response_id':response_id}))
|
||||
response_id = self.current_discussion.add_message(
|
||||
"GPT4All", ""
|
||||
) # first the content is empty, but we'll fill it at the end
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
"type": "input_message_infos",
|
||||
"message": message,
|
||||
"id": message_id,
|
||||
"response_id": response_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
self.current_message = "User: "+message+"\nGPT4All:"
|
||||
self.current_message = "User: " + message + "\nGPT4All:"
|
||||
self.prepare_query(self.current_message)
|
||||
chatbot_bindings.generate(self.current_message, n_predict=55, new_text_callback=self.new_text_callback, n_threads=8)
|
||||
chatbot_model_bindings.generate(
|
||||
self.current_message,
|
||||
n_predict=55,
|
||||
new_text_callback=self.new_text_callback,
|
||||
n_threads=8,
|
||||
)
|
||||
|
||||
self.current_discussion.update_message(response_id,self.bot_says)
|
||||
self.current_discussion.update_message(response_id, self.bot_says)
|
||||
yield self.bot_says
|
||||
# TODO : change this to use the yield version in order to send text word by word
|
||||
|
||||
return "\n".join(bot_says)
|
||||
|
||||
|
||||
def bot(self):
|
||||
self.stop=True
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
try:
|
||||
if self.current_discussion is None or not last_discussion_has_messages(self.db_path):
|
||||
self.current_discussion=Discussion.create_discussion(self.db_path)
|
||||
self.stop = True
|
||||
|
||||
message_id = self.current_discussion.add_message("user", request.json['message'])
|
||||
message = f"{request.json['message']}"
|
||||
try:
|
||||
if self.current_discussion is None or not last_discussion_has_messages(
|
||||
self.db_path
|
||||
):
|
||||
self.current_discussion = Discussion.create_discussion(self.db_path)
|
||||
|
||||
message_id = self.current_discussion.add_message(
|
||||
"user", request.json["message"]
|
||||
)
|
||||
message = f"{request.json['message']}"
|
||||
|
||||
# Segmented (the user receives the output as it comes)
|
||||
# We will first send a json entry that contains the message id and so on, then the text as it goes
|
||||
return Response(
|
||||
stream_with_context(
|
||||
self.parse_to_prompt_stream(message, message_id)
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
return (
|
||||
"<b style='color:red;'>Exception :<b>"
|
||||
+ str(ex)
|
||||
+ "<br>"
|
||||
+ traceback.format_exc()
|
||||
+ "<br>Please report exception"
|
||||
)
|
||||
|
||||
# Segmented (the user receives the output as it comes)
|
||||
# We will first send a json entry that contains the message id and so on, then the text as it goes
|
||||
return Response(stream_with_context(self.parse_to_prompt_stream(message, message_id)))
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
msg = traceback.print_exc()
|
||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
||||
|
||||
def discussions(self):
|
||||
try:
|
||||
discussions = Discussion.get_discussions(self.db_path)
|
||||
discussions = Discussion.get_discussions(self.db_path)
|
||||
return jsonify(discussions)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
msg = traceback.print_exc()
|
||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
||||
return (
|
||||
"<b style='color:red;'>Exception :<b>"
|
||||
+ str(ex)
|
||||
+ "<br>"
|
||||
+ traceback.format_exc()
|
||||
+ "<br>Please report exception"
|
||||
)
|
||||
|
||||
def rename(self):
|
||||
data = request.get_json()
|
||||
id = data['id']
|
||||
title = data['title']
|
||||
Discussion.rename(self.db_path, id, title)
|
||||
discussion_id = data["id"]
|
||||
title = data["title"]
|
||||
Discussion.rename(self.db_path, discussion_id, title)
|
||||
return "renamed successfully"
|
||||
|
||||
def get_messages(self):
|
||||
data = request.get_json()
|
||||
id = data['id']
|
||||
self.current_discussion = Discussion(id,self.db_path)
|
||||
discussion_id = data["id"]
|
||||
self.current_discussion = Discussion(discussion_id, self.db_path)
|
||||
messages = self.current_discussion.get_messages()
|
||||
return jsonify(messages)
|
||||
|
||||
|
||||
def delete_discussion(self):
|
||||
data = request.get_json()
|
||||
id = data['id']
|
||||
self.current_discussion = Discussion(id, self.db_path)
|
||||
discussion_id = data["id"]
|
||||
self.current_discussion = Discussion(discussion_id, self.db_path)
|
||||
self.current_discussion.delete_discussion()
|
||||
self.current_discussion = None
|
||||
return jsonify({})
|
||||
|
||||
|
||||
def update_message(self):
|
||||
try:
|
||||
id = request.args.get('id')
|
||||
new_message = request.args.get('message')
|
||||
self.current_discussion.update_message(id, new_message)
|
||||
return jsonify({"status":'ok'})
|
||||
discussion_id = request.args.get("id")
|
||||
new_message = request.args.get("message")
|
||||
self.current_discussion.update_message(discussion_id, new_message)
|
||||
return jsonify({"status": "ok"})
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
msg = traceback.print_exc()
|
||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
||||
return (
|
||||
"<b style='color:red;'>Exception :<b>"
|
||||
+ str(ex)
|
||||
+ "<br>"
|
||||
+ traceback.format_exc()
|
||||
+ "<br>Please report exception"
|
||||
)
|
||||
|
||||
def new_discussion(self):
|
||||
title = request.args.get('title')
|
||||
self.current_discussion= Discussion.create_discussion(self.db_path, title)
|
||||
title = request.args.get("title")
|
||||
self.current_discussion = Discussion.create_discussion(self.db_path, title)
|
||||
# Get the current timestamp
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# add a new discussion
|
||||
self.chatbot_bindings.close()
|
||||
self.chatbot_bindings.open()
|
||||
# self.chatbot_bindings.close()
|
||||
# self.chatbot_bindings.open()
|
||||
|
||||
# Return a success response
|
||||
return json.dumps({'id': self.current_discussion.discussion_id})
|
||||
return json.dumps({"id": self.current_discussion.discussion_id, "time": timestamp})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Start the chatbot Flask app.')
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Start the chatbot Flask app.")
|
||||
|
||||
parser.add_argument('--temp', type=float, default=0.1, help='Temperature parameter for the model.')
|
||||
parser.add_argument('--n_predict', type=int, default=128, help='Number of tokens to predict at each step.')
|
||||
parser.add_argument('--top_k', type=int, default=40, help='Value for the top-k sampling.')
|
||||
parser.add_argument('--top_p', type=float, default=0.95, help='Value for the top-p sampling.')
|
||||
parser.add_argument('--repeat_penalty', type=float, default=1.3, help='Penalty for repeated tokens.')
|
||||
parser.add_argument('--repeat_last_n', type=int, default=64, help='Number of previous tokens to consider for the repeat penalty.')
|
||||
parser.add_argument('--ctx_size', type=int, default=2048, help='Size of the context window for the model.')
|
||||
parser.add_argument('--debug', dest='debug', action='store_true', help='launch Flask server in debug mode')
|
||||
parser.add_argument('--host', type=str, default='localhost', help='the hostname to listen on')
|
||||
parser.add_argument('--port', type=int, default=9600, help='the port to listen on')
|
||||
parser.add_argument('--db_path', type=str, default='database.db', help='Database path')
|
||||
parser.add_argument(
|
||||
"--temp", type=float, default=0.1, help="Temperature parameter for the model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_predict",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Number of tokens to predict at each step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_k", type=int, default=40, help="Value for the top-k sampling."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p", type=float, default=0.95, help="Value for the top-p sampling."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat_penalty", type=float, default=1.3, help="Penalty for repeated tokens."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat_last_n",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of previous tokens to consider for the repeat penalty.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ctx_size",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Size of the context window for the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
dest="debug",
|
||||
action="store_true",
|
||||
help="launch Flask server in debug mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="localhost", help="the hostname to listen on"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=9600, help="the port to listen on")
|
||||
parser.add_argument(
|
||||
"--db_path", type=str, default="database.db", help="Database path"
|
||||
)
|
||||
parser.set_defaults(debug=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
chatbot_bindings = Model(ggml_model='./models/gpt4all-lora-quantized-ggml.bin', n_ctx=512)
|
||||
|
||||
chatbot_model_bindings = Model(
|
||||
ggml_model="./models/gpt4all-lora-quantized-ggml.bin", n_ctx=512
|
||||
)
|
||||
|
||||
# Old Code
|
||||
# GPT4All(decoder_config = {
|
||||
# 'temp': args.temp,
|
||||
@ -346,7 +469,7 @@ if __name__ == '__main__':
|
||||
# 'ctx_size': args.ctx_size
|
||||
# })
|
||||
check_discussion_db(args.db_path)
|
||||
bot = Gpt4AllWebUI(chatbot_bindings, app, args.db_path)
|
||||
bot = Gpt4AllWebUI(chatbot_model_bindings, app, args.db_path)
|
||||
|
||||
if args.debug:
|
||||
app.run(debug=True, host=args.host, port=args.port)
|
||||
|
@ -1,12 +1,15 @@
|
||||
import pytest
|
||||
|
||||
from app import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
with app.test_client() as client:
|
||||
yield client
|
||||
|
||||
|
||||
def test_homepage(client):
|
||||
response = client.get('/')
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert b"Welcome to my Flask app" in response.data
|
||||
|
Loading…
Reference in New Issue
Block a user