mirror of
https://github.com/ParisNeo/lollms-webui.git
synced 2024-12-18 20:17:50 +00:00
356 lines
14 KiB
Python
356 lines
14 KiB
Python
from flask import Flask, jsonify, request, render_template, Response, stream_with_context
|
|
from nomic.gpt4all import GPT4All
|
|
import argparse
|
|
import threading
|
|
from io import StringIO
|
|
import sys
|
|
import re
|
|
import sqlite3
|
|
from datetime import datetime
|
|
|
|
import sqlite3
|
|
import json
|
|
import time
|
|
import traceback
|
|
|
|
import select
|
|
|
|
#=================================== Database ==================================================================
|
|
class Discussion:
    """One chat discussion stored in the SQLite database at ``db_path``.

    Instances only hold the discussion id and the database path; every
    method opens a short-lived connection, runs its statements and closes
    the connection again.
    """

    def __init__(self, discussion_id, db_path='database.db'):
        self.discussion_id = discussion_id
        self.db_path = db_path

    @staticmethod
    def _write(db_path, statements):
        """Run ``(sql, params)`` pairs in one transaction.

        Returns the ``lastrowid`` of the cursor after the final statement.
        """
        conn = sqlite3.connect(db_path)
        try:
            # ``with conn`` only commits / rolls back the transaction; it
            # does NOT close the connection, hence the explicit close below.
            with conn:
                cur = conn.cursor()
                for sql, params in statements:
                    cur.execute(sql, params)
                return cur.lastrowid
        finally:
            conn.close()

    @staticmethod
    def _read(db_path, sql, params=()):
        """Run a query and return all of its rows as a list of tuples."""
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    @staticmethod
    def create_discussion(db_path='database.db', title='untitled'):
        """Insert a new discussion row and return a ``Discussion`` bound to it."""
        discussion_id = Discussion._write(
            db_path,
            [("INSERT INTO discussion (title) VALUES (?)", (title,))],
        )
        return Discussion(discussion_id, db_path)

    @staticmethod
    def get_discussion(db_path='database.db', id=0):
        """Return a ``Discussion`` handle for ``id`` (no database access is made)."""
        return Discussion(id, db_path)

    def add_message(self, sender, content):
        """Append a message to this discussion and return the new message id."""
        return Discussion._write(
            self.db_path,
            [('INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)',
              (sender, content, self.discussion_id))],
        )

    @staticmethod
    def get_discussions(db_path):
        """Return every discussion as ``{'id': ..., 'title': ...}``, ordered by id."""
        rows = Discussion._read(db_path, 'SELECT id, title FROM discussion ORDER BY id')
        return [{'id': row[0], 'title': row[1]} for row in rows]

    @staticmethod
    def rename(db_path, discussion_id, title):
        """Set the title of the discussion identified by ``discussion_id``."""
        Discussion._write(
            db_path,
            [('UPDATE discussion SET title=? WHERE id=?', (title, discussion_id))],
        )

    def delete_discussion(self):
        """Delete this discussion together with all of its messages."""
        Discussion._write(
            self.db_path,
            [
                ('DELETE FROM message WHERE discussion_id=?', (self.discussion_id,)),
                ('DELETE FROM discussion WHERE id=?', (self.discussion_id,)),
            ],
        )

    def get_messages(self):
        """Return this discussion's messages as dicts with sender, content and id."""
        rows = Discussion._read(
            self.db_path,
            'SELECT id, sender, content FROM message WHERE discussion_id=? ORDER BY id',
            (self.discussion_id,),
        )
        return [{'sender': row[1], 'content': row[2], 'id': row[0]} for row in rows]

    def update_message(self, message_id, new_content):
        """Replace the content of the message identified by ``message_id``."""
        Discussion._write(
            self.db_path,
            [('UPDATE message SET content = ? WHERE id = ?', (new_content, message_id))],
        )

    def remove_discussion(self):
        """Delete only the discussion row.

        Messages of the discussion are left in the ``message`` table; use
        ``delete_discussion`` to remove both.
        """
        Discussion._write(
            self.db_path,
            [('DELETE FROM discussion WHERE id=?', (self.discussion_id,))],
        )
|
|
|
|
def last_discussion_has_messages(db_path='database.db'):
    """Return True when the ``message`` table contains at least one row.

    NOTE: despite the name, the query is not restricted to the most recent
    discussion; it fetches the newest message across all discussions and
    reports whether one exists.
    """
    conn = sqlite3.connect(db_path)
    try:
        last_message = conn.execute(
            "SELECT * FROM message ORDER BY id DESC LIMIT 1"
        ).fetchone()
    finally:
        # The sqlite3 context manager would not close the connection.
        conn.close()
    return last_message is not None
|
|
|
|
def export_to_json(db_path='database.db'):
    """Return every discussion with its messages as JSON-serialisable data.

    The result is a list of ``{'id': <discussion id>, 'messages': [...]}``
    dicts, where each message is ``{'sender': ..., 'content': ...}``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        # Materialise the ids first: the same cursor is reused for the
        # per-discussion message queries below.
        cur.execute('SELECT * FROM discussion')
        discussion_ids = [row[0] for row in cur.fetchall()]

        discussions = []
        for discussion_id in discussion_ids:
            cur.execute('SELECT * FROM message WHERE discussion_id=?', (discussion_id,))
            messages = [
                {'sender': message_row[1], 'content': message_row[2]}
                for message_row in cur.fetchall()
            ]
            discussions.append({'id': discussion_id, 'messages': messages})
        return discussions
    finally:
        # The sqlite3 context manager would not close the connection.
        conn.close()
|
|
|
|
def remove_discussions(db_path='database.db'):
    """Delete every message and every discussion from the database."""
    conn = sqlite3.connect(db_path)
    try:
        # ``with conn`` commits both deletes together (or rolls both back).
        with conn:
            conn.execute('DELETE FROM message')
            conn.execute('DELETE FROM discussion')
    finally:
        # The sqlite3 context manager would not close the connection.
        conn.close()
|
|
|
|
# create database schema
|
|
def check_discussion_db(db_path):
    """Create the ``discussion`` and ``message`` tables if they do not exist.

    Safe to call on an existing database: both statements use
    ``CREATE TABLE IF NOT EXISTS``.
    """
    print("Checking discussions database...")
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commits the schema changes
            conn.execute('''
                CREATE TABLE IF NOT EXISTS discussion (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS message (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sender TEXT NOT NULL,
                    content TEXT NOT NULL,
                    discussion_id INTEGER NOT NULL,
                    FOREIGN KEY (discussion_id) REFERENCES discussion(id)
                )
            ''')
    finally:
        # The sqlite3 context manager would not close the connection.
        conn.close()

    print("Ok")
|
|
|
|
# ========================================================================================================================
|
|
|
|
|
|
|
|
# Module-level Flask application; files in the 'static' folder are served under the '/static' URL path.
app = Flask("GPT4All-WebUI", static_url_path='/static', static_folder='static')
|
|
class Gpt4AllWebUI():
    """Flask endpoints that relay chat messages to a GPT4All process.

    Messages and model replies are stored through the ``Discussion`` class
    in the SQLite database at ``db_path``.
    """

    def __init__(self, chatbot_bindings, app, db_path='database.db') -> None:
        self.current_discussion = None
        self.chatbot_bindings = chatbot_bindings
        self.app = app
        self.db_path = db_path
        # Set to True by bot() to ask a reply stream that is still running
        # to stop; reset to False when a new stream starts.
        self.stop = False

        self.add_endpoint('/', '', self.index, methods=['GET'])
        self.add_endpoint('/stream', 'stream', self.stream, methods=['GET'])
        self.add_endpoint('/export', 'export', self.export, methods=['GET'])
        self.add_endpoint('/new_discussion', 'new_discussion', self.new_discussion, methods=['GET'])
        self.add_endpoint('/bot', 'bot', self.bot, methods=['POST'])
        self.add_endpoint('/discussions', 'discussions', self.discussions, methods=['GET'])
        self.add_endpoint('/rename', 'rename', self.rename, methods=['POST'])
        self.add_endpoint('/get_messages', 'get_messages', self.get_messages, methods=['POST'])
        self.add_endpoint('/delete_discussion', 'delete_discussion', self.delete_discussion, methods=['POST'])
        self.add_endpoint('/update_message', 'update_message', self.update_message, methods=['GET'])

        # Chatbot conditioning (disabled):
        # response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:")
        # print(response)

    def add_endpoint(self, endpoint=None, endpoint_name=None, handler=None, methods=None, *args, **kwargs):
        """Register ``handler`` on the Flask app for ``endpoint`` (GET by default)."""
        if methods is None:
            methods = ['GET']
        self.app.add_url_rule(endpoint, endpoint_name, handler, *args, methods=methods, **kwargs)

    @staticmethod
    def _exception_response(ex):
        """Log the active exception and return the HTML error snippet sent to the client.

        Must be called from inside an ``except`` block so that the traceback
        of the exception being handled is available.
        """
        import html  # local import: only needed on the error path

        print(ex)
        traceback.print_exc()
        # Exception text and traceback are escaped so that characters such as
        # '<' in them are shown literally instead of being parsed as markup.
        return ("<b style='color:red;'>Exception :</b>" + html.escape(str(ex)) + "<br>"
                + html.escape(traceback.format_exc()) + "<br>Please report exception")

    def index(self):
        """Serve the chat page."""
        return render_template('chat.html')

    def format_message(self, message):
        """Replace every ```-fenced code block in ``message`` with a <code> element."""
        pattern = re.compile(r"```(.*?)```", re.DOTALL)
        # A function is used as replacement so backslashes inside the code
        # block are copied verbatim rather than treated as escapes.
        return pattern.sub(lambda match: f"<code>{match.group(1)}</code>", message)

    def stream(self):
        """Demo endpoint: stream ten lines of text, one per second."""
        def generate():
            # Replace this with your text-generating code
            for i in range(10):
                yield f'This is line {i+1}\n'
                time.sleep(1)

        return Response(stream_with_context(generate()))

    def export(self):
        """Return all discussions and their messages as JSON."""
        return jsonify(export_to_json(self.db_path))

    def parse_to_prompt_stream(self, message, message_id):
        """Send ``message`` to the model process and yield its reply as it arrives.

        The first item yielded is a JSON string carrying the user message,
        its id and the id of the (initially empty) reply row; every further
        item is one decoded character of the reply. Streaming ends when
        ``self.stop`` is set, when the model produces no output within the
        wait time, or when the model's stdout reaches end of file. The text
        received so far is then written to the reply row.

        The caller is expected to wrap the generator with
        ``stream_with_context`` (``bot`` does).
        """
        bot = self.chatbot_bindings.bot
        self.stop = False

        # Very important: this is the maximum time we wait for the model.
        # At the beginning the process may need time to produce data.
        wait_val = 15.0

        # Send the message to the bot
        print(f"Received message : {message}")
        bot.stdin.write(message.encode('utf-8'))
        bot.stdin.write(b"\n")
        bot.stdin.flush()

        # First we need to send the new message ID to the client. The reply
        # row starts empty and is filled in once streaming ends.
        discussion = self.current_discussion
        response_id = discussion.add_message("GPT4All", '')
        yield json.dumps({'type': 'input_message_infos', 'message': message,
                          'id': message_id, 'response_id': response_id})

        # Now let's wait for the bot to answer
        bot_says = []   # decoded characters received so far
        point = b''     # bytes of a character that is not complete yet
        try:
            while not self.stop:
                readable, _, _ = select.select([bot.stdout], [], [], wait_val)
                wait_val = 4.0  # Once started, the process answers faster, so wait less
                if bot.stdout not in readable:
                    break  # timed out: consider the answer finished
                chunk = bot.stdout.read(1)
                if not chunk:
                    break  # end of file: the model process closed its output
                point += chunk
                try:
                    character = point.decode("utf-8")
                except UnicodeDecodeError:
                    # Either a multi-byte character is still incomplete, or
                    # the bytes are invalid; UTF-8 needs at most 4 bytes.
                    if len(point) >= 4:
                        point = b''
                    continue
                point = b''
                bot_says.append(character)
                yield character
        finally:
            # Also runs when the client disconnects mid-stream, so the text
            # received so far is kept. The content is stored as one string.
            discussion.update_message(response_id, "".join(bot_says))
        return "".join(bot_says)

    def bot(self):
        """Store the posted user message and stream the model's reply."""
        # Ask a stream that may still be running to stop.
        self.stop = True
        try:
            if self.current_discussion is None or not last_discussion_has_messages(self.db_path):
                self.current_discussion = Discussion.create_discussion(self.db_path)

            message_id = self.current_discussion.add_message("user", request.json['message'])
            message = f"{request.json['message']}"

            # Segmented (the user receives the output as it comes).
            # We first send a json entry that contains the message ids, then the text as it goes.
            return Response(stream_with_context(self.parse_to_prompt_stream(message, message_id)))
        except Exception as ex:
            return self._exception_response(ex)

    def discussions(self):
        """Return the list of discussions as JSON."""
        try:
            return jsonify(Discussion.get_discussions(self.db_path))
        except Exception as ex:
            return self._exception_response(ex)

    def rename(self):
        """Rename the discussion given by the posted JSON ``id`` to ``title``."""
        data = request.get_json()
        Discussion.rename(self.db_path, data['id'], data['title'])
        return "renamed successfully"

    def get_messages(self):
        """Make the posted discussion current and return its messages as JSON."""
        data = request.get_json()
        self.current_discussion = Discussion(data['id'], self.db_path)
        return jsonify(self.current_discussion.get_messages())

    def delete_discussion(self):
        """Delete the posted discussion with its messages; no discussion is current afterwards."""
        data = request.get_json()
        self.current_discussion = Discussion(data['id'], self.db_path)
        self.current_discussion.delete_discussion()
        self.current_discussion = None
        return jsonify({})

    def update_message(self):
        """Replace the content of message ``id`` with the ``message`` query argument."""
        try:
            message_id = request.args.get('id')
            new_message = request.args.get('message')
            self.current_discussion.update_message(message_id, new_message)
            return jsonify({"status": 'ok'})
        except Exception as ex:
            return self._exception_response(ex)

    def new_discussion(self):
        """Create a discussion, make it current and restart the model process."""
        # Fall back to the same default title create_discussion uses when the
        # client sends none, instead of storing a NULL title.
        title = request.args.get('title', 'untitled')
        self.current_discussion = Discussion.create_discussion(self.db_path, title)

        # Close and reopen the model bindings for the new discussion.
        self.chatbot_bindings.close()
        self.chatbot_bindings.open()

        # Return a success response
        return json.dumps({'id': self.current_discussion.discussion_id})
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the command line, start the model bindings,
    # make sure the database schema exists, then run the Flask server.
    parser = argparse.ArgumentParser(description='Start the chatbot Flask app.')

    # (flag, type, default, help) for the model's decoding options.
    model_options = (
        ('--temp', float, 0.1, 'Temperature parameter for the model.'),
        ('--n_predict', int, 128, 'Number of tokens to predict at each step.'),
        ('--top_k', int, 40, 'Value for the top-k sampling.'),
        ('--top_p', float, 0.95, 'Value for the top-p sampling.'),
        ('--repeat_penalty', float, 1.3, 'Penalty for repeated tokens.'),
        ('--repeat_last_n', int, 64, 'Number of previous tokens to consider for the repeat penalty.'),
        ('--ctx_size', int, 2048, 'Size of the context window for the model.'),
    )
    # (flag, type, default, help) for the web server and storage options.
    server_options = (
        ('--host', str, 'localhost', 'the hostname to listen on'),
        ('--port', int, 9600, 'the port to listen on'),
        ('--db_path', str, 'database.db', 'Database path'),
    )

    for flag, value_type, default, help_text in model_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    parser.add_argument('--debug', dest='debug', action='store_true', help='launch Flask server in debug mode')
    for flag, value_type, default, help_text in server_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    parser.set_defaults(debug=False)

    args = parser.parse_args()

    # Decoding parameters handed to the GPT4All bindings.
    decoder_config = {
        'temp': args.temp,
        'n_predict': args.n_predict,
        'top_k': args.top_k,
        'top_p': args.top_p,
        'repeat_penalty': args.repeat_penalty,
        'repeat_last_n': args.repeat_last_n,
        'ctx_size': args.ctx_size,
    }
    chatbot_bindings = GPT4All(decoder_config=decoder_config)
    chatbot_bindings.open()

    check_discussion_db(args.db_path)
    bot = Gpt4AllWebUI(chatbot_bindings, app, args.db_path)

    # The debug flag is only passed to Flask when it was requested.
    run_options = {'host': args.host, 'port': args.port}
    if args.debug:
        run_options['debug'] = True
    app.run(**run_options)
|