mirror of
https://github.com/ParisNeo/lollms-webui.git
synced 2024-12-21 13:17:47 +00:00
Merge pull request #30 from NJannasch/feature/python-cleanup
Update python code (isort, black, pylint) and some manual tuning
This commit is contained in:
commit
b7c2c225d3
391
app.py
391
app.py
@ -1,129 +1,150 @@
|
|||||||
from flask import Flask, jsonify, request, render_template, Response, stream_with_context
|
|
||||||
from pyllamacpp.model import Model
|
|
||||||
import argparse
|
import argparse
|
||||||
import threading
|
import json
|
||||||
from io import StringIO
|
|
||||||
import sys
|
|
||||||
import re
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import sqlite3
|
from flask import (
|
||||||
import json
|
Flask,
|
||||||
import time
|
Response,
|
||||||
import traceback
|
jsonify,
|
||||||
|
render_template,
|
||||||
|
request,
|
||||||
|
stream_with_context,
|
||||||
|
)
|
||||||
|
from pyllamacpp.model import Model
|
||||||
|
|
||||||
import select
|
|
||||||
|
|
||||||
#=================================== Database ==================================================================
|
# =================================== Database ==================================================================
|
||||||
class Discussion:
|
class Discussion:
|
||||||
def __init__(self, discussion_id, db_path='database.db'):
|
def __init__(self, discussion_id, db_path="database.db"):
|
||||||
self.discussion_id = discussion_id
|
self.discussion_id = discussion_id
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_discussion(db_path='database.db', title='untitled'):
|
def create_discussion(db_path="database.db", title="untitled"):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute("INSERT INTO discussion (title) VALUES (?)", (title,))
|
cur.execute("INSERT INTO discussion (title) VALUES (?)", (title,))
|
||||||
discussion_id = cur.lastrowid
|
discussion_id = cur.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return Discussion(discussion_id, db_path)
|
return Discussion(discussion_id, db_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_discussion(db_path='database.db', id=0):
|
def get_discussion(db_path="database.db", discussion_id=0):
|
||||||
return Discussion(id, db_path)
|
return Discussion(discussion_id, db_path)
|
||||||
|
|
||||||
def add_message(self, sender, content):
|
def add_message(self, sender, content):
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)',
|
cur.execute(
|
||||||
(sender, content, self.discussion_id))
|
"INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)",
|
||||||
|
(sender, content, self.discussion_id),
|
||||||
|
)
|
||||||
message_id = cur.lastrowid
|
message_id = cur.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_discussions(db_path):
|
def get_discussions(db_path):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('SELECT * FROM discussion')
|
cursor.execute("SELECT * FROM discussion")
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
return [{'id': row[0], 'title': row[1]} for row in rows]
|
return [{"id": row[0], "title": row[1]} for row in rows]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rename(db_path, discussion_id, title):
|
def rename(db_path, discussion_id, title):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute('UPDATE discussion SET title=? WHERE id=?', (title, discussion_id))
|
cursor.execute(
|
||||||
|
"UPDATE discussion SET title=? WHERE id=?", (title, discussion_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def delete_discussion(self):
|
def delete_discussion(self):
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('DELETE FROM message WHERE discussion_id=?', (self.discussion_id,))
|
cur.execute(
|
||||||
cur.execute('DELETE FROM discussion WHERE id=?', (self.discussion_id,))
|
"DELETE FROM message WHERE discussion_id=?", (self.discussion_id,)
|
||||||
|
)
|
||||||
|
cur.execute("DELETE FROM discussion WHERE id=?", (self.discussion_id,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def get_messages(self):
|
def get_messages(self):
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('SELECT * FROM message WHERE discussion_id=?', (self.discussion_id,))
|
cur.execute(
|
||||||
|
"SELECT * FROM message WHERE discussion_id=?", (self.discussion_id,)
|
||||||
|
)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
return [{'sender': row[1], 'content': row[2], 'id':row[0]} for row in rows]
|
return [{"sender": row[1], "content": row[2], "id": row[0]} for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def update_message(self, message_id, new_content):
|
def update_message(self, message_id, new_content):
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('UPDATE message SET content = ? WHERE id = ?', (new_content, message_id))
|
cur.execute(
|
||||||
|
"UPDATE message SET content = ? WHERE id = ?", (new_content, message_id)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def remove_discussion(self):
|
def remove_discussion(self):
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
conn.cursor().execute('DELETE FROM discussion WHERE id=?', (self.discussion_id,))
|
conn.cursor().execute(
|
||||||
|
"DELETE FROM discussion WHERE id=?", (self.discussion_id,)
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def last_discussion_has_messages(db_path='database.db'):
|
|
||||||
|
def last_discussion_has_messages(db_path="database.db"):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
c = conn.cursor()
|
cursor = conn.cursor()
|
||||||
c.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
|
cursor.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
|
||||||
last_message = c.fetchone()
|
last_message = cursor.fetchone()
|
||||||
return last_message is not None
|
return last_message is not None
|
||||||
|
|
||||||
def export_to_json(db_path='database.db'):
|
|
||||||
|
def export_to_json(db_path="database.db"):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('SELECT * FROM discussion')
|
cur.execute("SELECT * FROM discussion")
|
||||||
discussions = []
|
discussions = []
|
||||||
for row in cur.fetchall():
|
for row in cur.fetchall():
|
||||||
discussion_id = row[0]
|
discussion_id = row[0]
|
||||||
discussion = {'id': discussion_id, 'messages': []}
|
discussion = {"id": discussion_id, "messages": []}
|
||||||
cur.execute('SELECT * FROM message WHERE discussion_id=?', (discussion_id,))
|
cur.execute("SELECT * FROM message WHERE discussion_id=?", (discussion_id,))
|
||||||
for message_row in cur.fetchall():
|
for message_row in cur.fetchall():
|
||||||
discussion['messages'].append({'sender': message_row[1], 'content': message_row[2]})
|
discussion["messages"].append(
|
||||||
|
{"sender": message_row[1], "content": message_row[2]}
|
||||||
|
)
|
||||||
discussions.append(discussion)
|
discussions.append(discussion)
|
||||||
return discussions
|
return discussions
|
||||||
|
|
||||||
def remove_discussions(db_path='database.db'):
|
|
||||||
|
def remove_discussions(db_path="database.db"):
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('DELETE FROM message')
|
cur.execute("DELETE FROM message")
|
||||||
cur.execute('DELETE FROM discussion')
|
cur.execute("DELETE FROM discussion")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
# create database schema
|
# create database schema
|
||||||
def check_discussion_db(db_path):
|
def check_discussion_db(db_path):
|
||||||
print("Checking discussions database...")
|
print("Checking discussions database...")
|
||||||
with sqlite3.connect(db_path) as conn:
|
with sqlite3.connect(db_path) as conn:
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute('''
|
cur.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS discussion (
|
CREATE TABLE IF NOT EXISTS discussion (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
title TEXT
|
title TEXT
|
||||||
)
|
)
|
||||||
''')
|
"""
|
||||||
cur.execute('''
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS message (
|
CREATE TABLE IF NOT EXISTS message (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
sender TEXT NOT NULL,
|
sender TEXT NOT NULL,
|
||||||
@ -131,58 +152,78 @@ def check_discussion_db(db_path):
|
|||||||
discussion_id INTEGER NOT NULL,
|
discussion_id INTEGER NOT NULL,
|
||||||
FOREIGN KEY (discussion_id) REFERENCES discussion(id)
|
FOREIGN KEY (discussion_id) REFERENCES discussion(id)
|
||||||
)
|
)
|
||||||
''')
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
print("Ok")
|
print("Ok")
|
||||||
|
|
||||||
|
|
||||||
# ========================================================================================================================
|
# ========================================================================================================================
|
||||||
|
|
||||||
|
|
||||||
|
app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
|
||||||
|
|
||||||
app = Flask("GPT4All-WebUI", static_url_path='/static', static_folder='static')
|
|
||||||
class Gpt4AllWebUI():
|
class Gpt4AllWebUI:
|
||||||
def __init__(self, chatbot_bindings, app, db_path='database.db') -> None:
|
def __init__(self, chatbot_bindings, _app, db_path="database.db") -> None:
|
||||||
self.current_discussion = None
|
self.current_discussion = None
|
||||||
self.chatbot_bindings = chatbot_bindings
|
self.chatbot_bindings = chatbot_bindings
|
||||||
self.app=app
|
self.app = _app
|
||||||
self.db_path= db_path
|
self.db_path = db_path
|
||||||
self.add_endpoint('/', '', self.index, methods=['GET'])
|
self.add_endpoint("/", "", self.index, methods=["GET"])
|
||||||
self.add_endpoint('/export', 'export', self.export, methods=['GET'])
|
self.add_endpoint("/export", "export", self.export, methods=["GET"])
|
||||||
self.add_endpoint('/new_discussion', 'new_discussion', self.new_discussion, methods=['GET'])
|
self.add_endpoint(
|
||||||
self.add_endpoint('/bot', 'bot', self.bot, methods=['POST'])
|
"/new_discussion", "new_discussion", self.new_discussion, methods=["GET"]
|
||||||
self.add_endpoint('/discussions', 'discussions', self.discussions, methods=['GET'])
|
)
|
||||||
self.add_endpoint('/rename', 'rename', self.rename, methods=['POST'])
|
self.add_endpoint("/bot", "bot", self.bot, methods=["POST"])
|
||||||
self.add_endpoint('/get_messages', 'get_messages', self.get_messages, methods=['POST'])
|
self.add_endpoint(
|
||||||
self.add_endpoint('/delete_discussion', 'delete_discussion', self.delete_discussion, methods=['POST'])
|
"/discussions", "discussions", self.discussions, methods=["GET"]
|
||||||
|
)
|
||||||
|
self.add_endpoint("/rename", "rename", self.rename, methods=["POST"])
|
||||||
|
self.add_endpoint(
|
||||||
|
"/get_messages", "get_messages", self.get_messages, methods=["POST"]
|
||||||
|
)
|
||||||
|
self.add_endpoint(
|
||||||
|
"/delete_discussion",
|
||||||
|
"delete_discussion",
|
||||||
|
self.delete_discussion,
|
||||||
|
methods=["POST"],
|
||||||
|
)
|
||||||
|
|
||||||
self.add_endpoint('/update_message', 'update_message', self.update_message, methods=['GET'])
|
self.add_endpoint(
|
||||||
|
"/update_message", "update_message", self.update_message, methods=["GET"]
|
||||||
conditionning_message="""
|
)
|
||||||
|
|
||||||
|
conditionning_message = """
|
||||||
Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems.
|
Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems.
|
||||||
Start by welcoming the user then stop sending text.
|
Start by welcoming the user then stop sending text.
|
||||||
GPT4All:"""
|
GPT4All:"""
|
||||||
self.prepare_query(conditionning_message)
|
self.prepare_query(conditionning_message)
|
||||||
chatbot_bindings.generate(conditionning_message, n_predict=55, new_text_callback=self.new_text_callback, n_threads=8)
|
chatbot_bindings.generate(
|
||||||
print(f"Bot said:{self.bot_says}")
|
conditionning_message,
|
||||||
|
n_predict=55,
|
||||||
|
new_text_callback=self.new_text_callback,
|
||||||
|
n_threads=8,
|
||||||
|
)
|
||||||
|
print(f"Bot said:{self.bot_says}")
|
||||||
# Chatbot conditionning
|
# Chatbot conditionning
|
||||||
# response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:")
|
# response = self.chatbot_bindings.prompt("This is a discussion between A user and an AI. AI responds to user questions in a helpful manner. AI is not allowed to lie or deceive. AI welcomes the user\n### Response:")
|
||||||
# print(response)
|
# print(response)
|
||||||
|
|
||||||
def prepare_query(self, message):
|
def prepare_query(self, message):
|
||||||
self.bot_says=''
|
self.bot_says = ""
|
||||||
self.full_text=''
|
self.full_text = ""
|
||||||
self.is_bot_text_started=False
|
self.is_bot_text_started = False
|
||||||
self.current_message = message
|
self.current_message = message
|
||||||
|
|
||||||
|
|
||||||
def new_text_callback(self, text: str):
|
def new_text_callback(self, text: str):
|
||||||
print(text, end="")
|
print(text, end="")
|
||||||
self.full_text += text
|
self.full_text += text
|
||||||
if self.is_bot_text_started:
|
if self.is_bot_text_started:
|
||||||
self.bot_says += text
|
self.bot_says += text
|
||||||
if self.current_message in self.full_text:
|
if self.current_message in self.full_text:
|
||||||
self.is_bot_text_started=True
|
self.is_bot_text_started = True
|
||||||
|
|
||||||
def new_text_callback_with_yield(self, text: str):
|
def new_text_callback_with_yield(self, text: str):
|
||||||
"""
|
"""
|
||||||
@ -193,14 +234,24 @@ GPT4All:"""
|
|||||||
if self.is_bot_text_started:
|
if self.is_bot_text_started:
|
||||||
self.bot_says += text
|
self.bot_says += text
|
||||||
if self.current_message in self.full_text:
|
if self.current_message in self.full_text:
|
||||||
self.is_bot_text_started=True
|
self.is_bot_text_started = True
|
||||||
yield text
|
yield text
|
||||||
|
|
||||||
def add_endpoint(self, endpoint=None, endpoint_name=None, handler=None, methods=['GET'], *args, **kwargs):
|
def add_endpoint(
|
||||||
self.app.add_url_rule(endpoint, endpoint_name, handler, methods=methods, *args, **kwargs)
|
self,
|
||||||
|
endpoint=None,
|
||||||
|
endpoint_name=None,
|
||||||
|
handler=None,
|
||||||
|
methods=["GET"],
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.app.add_url_rule(
|
||||||
|
endpoint, endpoint_name, handler, methods=methods, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def index(self):
|
def index(self):
|
||||||
return render_template('chat.html')
|
return render_template("chat.html")
|
||||||
|
|
||||||
def format_message(self, message):
|
def format_message(self, message):
|
||||||
# Look for a code block within the message
|
# Look for a code block within the message
|
||||||
@ -220,120 +271,192 @@ GPT4All:"""
|
|||||||
|
|
||||||
@stream_with_context
|
@stream_with_context
|
||||||
def parse_to_prompt_stream(self, message, message_id):
|
def parse_to_prompt_stream(self, message, message_id):
|
||||||
bot_says = ''
|
bot_says = ""
|
||||||
self.stop=False
|
self.stop = False
|
||||||
|
|
||||||
# send the message to the bot
|
# send the message to the bot
|
||||||
print(f"Received message : {message}")
|
print(f"Received message : {message}")
|
||||||
# First we need to send the new message ID to the client
|
# First we need to send the new message ID to the client
|
||||||
response_id = self.current_discussion.add_message("GPT4All",'') # first the content is empty, but we'll fill it at the end
|
response_id = self.current_discussion.add_message(
|
||||||
yield(json.dumps({'type':'input_message_infos','message':message, 'id':message_id, 'response_id':response_id}))
|
"GPT4All", ""
|
||||||
|
) # first the content is empty, but we'll fill it at the end
|
||||||
|
yield (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "input_message_infos",
|
||||||
|
"message": message,
|
||||||
|
"id": message_id,
|
||||||
|
"response_id": response_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.current_message = "User: "+message+"\nGPT4All:"
|
self.current_message = "User: " + message + "\nGPT4All:"
|
||||||
self.prepare_query(self.current_message)
|
self.prepare_query(self.current_message)
|
||||||
chatbot_bindings.generate(self.current_message, n_predict=55, new_text_callback=self.new_text_callback, n_threads=8)
|
chatbot_model_bindings.generate(
|
||||||
|
self.current_message,
|
||||||
|
n_predict=55,
|
||||||
|
new_text_callback=self.new_text_callback,
|
||||||
|
n_threads=8,
|
||||||
|
)
|
||||||
|
|
||||||
self.current_discussion.update_message(response_id,self.bot_says)
|
self.current_discussion.update_message(response_id, self.bot_says)
|
||||||
yield self.bot_says
|
yield self.bot_says
|
||||||
# TODO : change this to use the yield version in order to send text word by word
|
# TODO : change this to use the yield version in order to send text word by word
|
||||||
|
|
||||||
return "\n".join(bot_says)
|
return "\n".join(bot_says)
|
||||||
|
|
||||||
def bot(self):
|
def bot(self):
|
||||||
self.stop=True
|
self.stop = True
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
|
||||||
try:
|
|
||||||
if self.current_discussion is None or not last_discussion_has_messages(self.db_path):
|
|
||||||
self.current_discussion=Discussion.create_discussion(self.db_path)
|
|
||||||
|
|
||||||
message_id = self.current_discussion.add_message("user", request.json['message'])
|
try:
|
||||||
message = f"{request.json['message']}"
|
if self.current_discussion is None or not last_discussion_has_messages(
|
||||||
|
self.db_path
|
||||||
|
):
|
||||||
|
self.current_discussion = Discussion.create_discussion(self.db_path)
|
||||||
|
|
||||||
|
message_id = self.current_discussion.add_message(
|
||||||
|
"user", request.json["message"]
|
||||||
|
)
|
||||||
|
message = f"{request.json['message']}"
|
||||||
|
|
||||||
|
# Segmented (the user receives the output as it comes)
|
||||||
|
# We will first send a json entry that contains the message id and so on, then the text as it goes
|
||||||
|
return Response(
|
||||||
|
stream_with_context(
|
||||||
|
self.parse_to_prompt_stream(message, message_id)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
print(ex)
|
||||||
|
return (
|
||||||
|
"<b style='color:red;'>Exception :<b>"
|
||||||
|
+ str(ex)
|
||||||
|
+ "<br>"
|
||||||
|
+ traceback.format_exc()
|
||||||
|
+ "<br>Please report exception"
|
||||||
|
)
|
||||||
|
|
||||||
# Segmented (the user receives the output as it comes)
|
|
||||||
# We will first send a json entry that contains the message id and so on, then the text as it goes
|
|
||||||
return Response(stream_with_context(self.parse_to_prompt_stream(message, message_id)))
|
|
||||||
except Exception as ex:
|
|
||||||
print(ex)
|
|
||||||
msg = traceback.print_exc()
|
|
||||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
|
||||||
|
|
||||||
def discussions(self):
|
def discussions(self):
|
||||||
try:
|
try:
|
||||||
discussions = Discussion.get_discussions(self.db_path)
|
discussions = Discussion.get_discussions(self.db_path)
|
||||||
return jsonify(discussions)
|
return jsonify(discussions)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
msg = traceback.print_exc()
|
return (
|
||||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
"<b style='color:red;'>Exception :<b>"
|
||||||
|
+ str(ex)
|
||||||
|
+ "<br>"
|
||||||
|
+ traceback.format_exc()
|
||||||
|
+ "<br>Please report exception"
|
||||||
|
)
|
||||||
|
|
||||||
def rename(self):
|
def rename(self):
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
id = data['id']
|
discussion_id = data["id"]
|
||||||
title = data['title']
|
title = data["title"]
|
||||||
Discussion.rename(self.db_path, id, title)
|
Discussion.rename(self.db_path, discussion_id, title)
|
||||||
return "renamed successfully"
|
return "renamed successfully"
|
||||||
|
|
||||||
def get_messages(self):
|
def get_messages(self):
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
id = data['id']
|
discussion_id = data["id"]
|
||||||
self.current_discussion = Discussion(id,self.db_path)
|
self.current_discussion = Discussion(discussion_id, self.db_path)
|
||||||
messages = self.current_discussion.get_messages()
|
messages = self.current_discussion.get_messages()
|
||||||
return jsonify(messages)
|
return jsonify(messages)
|
||||||
|
|
||||||
|
|
||||||
def delete_discussion(self):
|
def delete_discussion(self):
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
id = data['id']
|
discussion_id = data["id"]
|
||||||
self.current_discussion = Discussion(id, self.db_path)
|
self.current_discussion = Discussion(discussion_id, self.db_path)
|
||||||
self.current_discussion.delete_discussion()
|
self.current_discussion.delete_discussion()
|
||||||
self.current_discussion = None
|
self.current_discussion = None
|
||||||
return jsonify({})
|
return jsonify({})
|
||||||
|
|
||||||
def update_message(self):
|
def update_message(self):
|
||||||
try:
|
try:
|
||||||
id = request.args.get('id')
|
discussion_id = request.args.get("id")
|
||||||
new_message = request.args.get('message')
|
new_message = request.args.get("message")
|
||||||
self.current_discussion.update_message(id, new_message)
|
self.current_discussion.update_message(discussion_id, new_message)
|
||||||
return jsonify({"status":'ok'})
|
return jsonify({"status": "ok"})
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
msg = traceback.print_exc()
|
return (
|
||||||
return "<b style='color:red;'>Exception :<b>"+str(ex)+"<br>"+traceback.format_exc()+"<br>Please report exception"
|
"<b style='color:red;'>Exception :<b>"
|
||||||
|
+ str(ex)
|
||||||
|
+ "<br>"
|
||||||
|
+ traceback.format_exc()
|
||||||
|
+ "<br>Please report exception"
|
||||||
|
)
|
||||||
|
|
||||||
def new_discussion(self):
|
def new_discussion(self):
|
||||||
title = request.args.get('title')
|
title = request.args.get("title")
|
||||||
self.current_discussion= Discussion.create_discussion(self.db_path, title)
|
self.current_discussion = Discussion.create_discussion(self.db_path, title)
|
||||||
# Get the current timestamp
|
# Get the current timestamp
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
# add a new discussion
|
# add a new discussion
|
||||||
self.chatbot_bindings.close()
|
# self.chatbot_bindings.close()
|
||||||
self.chatbot_bindings.open()
|
# self.chatbot_bindings.open()
|
||||||
|
|
||||||
# Return a success response
|
# Return a success response
|
||||||
return json.dumps({'id': self.current_discussion.discussion_id})
|
return json.dumps({"id": self.current_discussion.discussion_id, "time": timestamp})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description='Start the chatbot Flask app.')
|
parser = argparse.ArgumentParser(description="Start the chatbot Flask app.")
|
||||||
|
|
||||||
parser.add_argument('--temp', type=float, default=0.1, help='Temperature parameter for the model.')
|
parser.add_argument(
|
||||||
parser.add_argument('--n_predict', type=int, default=128, help='Number of tokens to predict at each step.')
|
"--temp", type=float, default=0.1, help="Temperature parameter for the model."
|
||||||
parser.add_argument('--top_k', type=int, default=40, help='Value for the top-k sampling.')
|
)
|
||||||
parser.add_argument('--top_p', type=float, default=0.95, help='Value for the top-p sampling.')
|
parser.add_argument(
|
||||||
parser.add_argument('--repeat_penalty', type=float, default=1.3, help='Penalty for repeated tokens.')
|
"--n_predict",
|
||||||
parser.add_argument('--repeat_last_n', type=int, default=64, help='Number of previous tokens to consider for the repeat penalty.')
|
type=int,
|
||||||
parser.add_argument('--ctx_size', type=int, default=2048, help='Size of the context window for the model.')
|
default=128,
|
||||||
parser.add_argument('--debug', dest='debug', action='store_true', help='launch Flask server in debug mode')
|
help="Number of tokens to predict at each step.",
|
||||||
parser.add_argument('--host', type=str, default='localhost', help='the hostname to listen on')
|
)
|
||||||
parser.add_argument('--port', type=int, default=9600, help='the port to listen on')
|
parser.add_argument(
|
||||||
parser.add_argument('--db_path', type=str, default='database.db', help='Database path')
|
"--top_k", type=int, default=40, help="Value for the top-k sampling."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p", type=float, default=0.95, help="Value for the top-p sampling."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat_penalty", type=float, default=1.3, help="Penalty for repeated tokens."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat_last_n",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Number of previous tokens to consider for the repeat penalty.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ctx_size",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Size of the context window for the model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
dest="debug",
|
||||||
|
action="store_true",
|
||||||
|
help="launch Flask server in debug mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host", type=str, default="localhost", help="the hostname to listen on"
|
||||||
|
)
|
||||||
|
parser.add_argument("--port", type=int, default=9600, help="the port to listen on")
|
||||||
|
parser.add_argument(
|
||||||
|
"--db_path", type=str, default="database.db", help="Database path"
|
||||||
|
)
|
||||||
parser.set_defaults(debug=False)
|
parser.set_defaults(debug=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
chatbot_bindings = Model(ggml_model='./models/gpt4all-lora-quantized-ggml.bin', n_ctx=512)
|
chatbot_model_bindings = Model(
|
||||||
|
ggml_model="./models/gpt4all-lora-quantized-ggml.bin", n_ctx=512
|
||||||
|
)
|
||||||
|
|
||||||
# Old Code
|
# Old Code
|
||||||
# GPT4All(decoder_config = {
|
# GPT4All(decoder_config = {
|
||||||
# 'temp': args.temp,
|
# 'temp': args.temp,
|
||||||
@ -346,7 +469,7 @@ if __name__ == '__main__':
|
|||||||
# 'ctx_size': args.ctx_size
|
# 'ctx_size': args.ctx_size
|
||||||
# })
|
# })
|
||||||
check_discussion_db(args.db_path)
|
check_discussion_db(args.db_path)
|
||||||
bot = Gpt4AllWebUI(chatbot_bindings, app, args.db_path)
|
bot = Gpt4AllWebUI(chatbot_model_bindings, app, args.db_path)
|
||||||
|
|
||||||
if args.debug:
|
if args.debug:
|
||||||
app.run(debug=True, host=args.host, port=args.port)
|
app.run(debug=True, host=args.host, port=args.port)
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app import app
|
from app import app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
with app.test_client() as client:
|
with app.test_client() as client:
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
def test_homepage(client):
|
def test_homepage(client):
|
||||||
response = client.get('/')
|
response = client.get("/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert b"Welcome to my Flask app" in response.data
|
assert b"Welcome to my Flask app" in response.data
|
||||||
|
Loading…
Reference in New Issue
Block a user