Merge pull request #51 from ParisNeo/flask_sse

Multimodel is working
This commit is contained in:
Saifeddine ALOUI 2023-04-08 19:57:12 +02:00 committed by GitHub
commit 438e2be7cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 229 additions and 152 deletions

179
app.py
View File

@ -1,12 +1,11 @@
import argparse
import json
import re
import sqlite3
import traceback
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import sys
from db import Discussion, export_to_json, check_discussion_db, last_discussion_has_messages
from flask import (
Flask,
Response,
@ -14,155 +13,11 @@ from flask import (
render_template,
request,
stream_with_context,
send_from_directory
)
from pyllamacpp.model import Model
from queue import Queue
# =================================== Database ==================================================================
class Discussion:
def __init__(self, discussion_id, db_path="database.db"):
self.discussion_id = discussion_id
self.db_path = db_path
@staticmethod
def create_discussion(db_path="database.db", title="untitled"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("INSERT INTO discussion (title) VALUES (?)", (title,))
discussion_id = cur.lastrowid
conn.commit()
return Discussion(discussion_id, db_path)
@staticmethod
def get_discussion(db_path="database.db", discussion_id=0):
return Discussion(discussion_id, db_path)
def add_message(self, sender, content):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)",
(sender, content, self.discussion_id),
)
message_id = cur.lastrowid
conn.commit()
return message_id
@staticmethod
def get_discussions(db_path):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM discussion")
rows = cursor.fetchall()
return [{"id": row[0], "title": row[1]} for row in rows]
@staticmethod
def rename(db_path, discussion_id, title):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE discussion SET title=? WHERE id=?", (title, discussion_id)
)
conn.commit()
def delete_discussion(self):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"DELETE FROM message WHERE discussion_id=?", (self.discussion_id,)
)
cur.execute("DELETE FROM discussion WHERE id=?", (self.discussion_id,))
conn.commit()
def get_messages(self):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"SELECT * FROM message WHERE discussion_id=?", (self.discussion_id,)
)
rows = cur.fetchall()
return [{"sender": row[1], "content": row[2], "id": row[0]} for row in rows]
def update_message(self, message_id, new_content):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"UPDATE message SET content = ? WHERE id = ?", (new_content, message_id)
)
conn.commit()
def remove_discussion(self):
with sqlite3.connect(self.db_path) as conn:
conn.cursor().execute(
"DELETE FROM discussion WHERE id=?", (self.discussion_id,)
)
conn.commit()
def last_discussion_has_messages(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
last_message = cursor.fetchone()
return last_message is not None
def export_to_json(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("SELECT * FROM discussion")
discussions = []
for row in cur.fetchall():
discussion_id = row[0]
discussion = {"id": discussion_id, "messages": []}
cur.execute("SELECT * FROM message WHERE discussion_id=?", (discussion_id,))
for message_row in cur.fetchall():
discussion["messages"].append(
{"sender": message_row[1], "content": message_row[2]}
)
discussions.append(discussion)
return discussions
def remove_discussions(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("DELETE FROM message")
cur.execute("DELETE FROM discussion")
conn.commit()
# create database schema
def check_discussion_db(db_path):
print("Checking discussions database...")
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS discussion (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sender TEXT NOT NULL,
content TEXT NOT NULL,
discussion_id INTEGER NOT NULL,
FOREIGN KEY (discussion_id) REFERENCES discussion(id)
)
"""
)
conn.commit()
print("Ok")
# ========================================================================================================================
from pathlib import Path
app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
@ -174,7 +29,10 @@ class Gpt4AllWebUI:
self.app = _app
self.db_path = args.db_path
# This is the queue used to stream text to the ui as the bot spits out its response
self.text_queue = Queue(0)
self.add_endpoint("/", "", self.index, methods=["GET"])
self.add_endpoint("/export", "export", self.export, methods=["GET"])
@ -203,9 +61,21 @@ class Gpt4AllWebUI:
"/update_model_params", "update_model_params", self.update_model_params, methods=["POST"]
)
self.add_endpoint(
"/list_models", "list_models", self.list_models, methods=["GET"]
)
self.add_endpoint(
"/get_args", "get_args", self.get_args, methods=["GET"]
)
self.prepare_a_new_chatbot()
def list_models(self):
models_dir = Path('./models') # replace with the actual path to the models folder
models = [f.name for f in models_dir.glob('*.bin')]
return jsonify(models)
def prepare_a_new_chatbot(self):
# Create chatbot
@ -463,7 +333,12 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo
def update_model_params(self):
data = request.get_json()
self.args.model = str(data["model"])
model = str(data["model"])
if self.args.model != model:
print("New model selected")
self.args.model = model
self.prepare_a_new_chatbot()
self.args.n_predict = int(data["nPredict"])
self.args.seed = int(data["seed"])
@ -482,7 +357,11 @@ GPT4All:Welcome! I'm here to assist you with anything you need. What can I do fo
print(f"\trepeat_penalty:{self.args.repeat_penalty}")
print(f"\trepeat_last_n:{self.args.repeat_last_n}")
return jsonify({"status":"ok"})
def get_args(self):
return jsonify(self.args)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Start the chatbot Flask app.")

146
db.py Normal file
View File

@ -0,0 +1,146 @@
import sqlite3
# =================================== Database ==================================================================
class Discussion:
def __init__(self, discussion_id, db_path="database.db"):
self.discussion_id = discussion_id
self.db_path = db_path
@staticmethod
def create_discussion(db_path="database.db", title="untitled"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("INSERT INTO discussion (title) VALUES (?)", (title,))
discussion_id = cur.lastrowid
conn.commit()
return Discussion(discussion_id, db_path)
@staticmethod
def get_discussion(db_path="database.db", discussion_id=0):
return Discussion(discussion_id, db_path)
def add_message(self, sender, content):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)",
(sender, content, self.discussion_id),
)
message_id = cur.lastrowid
conn.commit()
return message_id
@staticmethod
def get_discussions(db_path):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM discussion")
rows = cursor.fetchall()
return [{"id": row[0], "title": row[1]} for row in rows]
@staticmethod
def rename(db_path, discussion_id, title):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"UPDATE discussion SET title=? WHERE id=?", (title, discussion_id)
)
conn.commit()
def delete_discussion(self):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"DELETE FROM message WHERE discussion_id=?", (self.discussion_id,)
)
cur.execute("DELETE FROM discussion WHERE id=?", (self.discussion_id,))
conn.commit()
def get_messages(self):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"SELECT * FROM message WHERE discussion_id=?", (self.discussion_id,)
)
rows = cur.fetchall()
return [{"sender": row[1], "content": row[2], "id": row[0]} for row in rows]
def update_message(self, message_id, new_content):
with sqlite3.connect(self.db_path) as conn:
cur = conn.cursor()
cur.execute(
"UPDATE message SET content = ? WHERE id = ?", (new_content, message_id)
)
conn.commit()
def remove_discussion(self):
with sqlite3.connect(self.db_path) as conn:
conn.cursor().execute(
"DELETE FROM discussion WHERE id=?", (self.discussion_id,)
)
conn.commit()
def last_discussion_has_messages(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
last_message = cursor.fetchone()
return last_message is not None
def export_to_json(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("SELECT * FROM discussion")
discussions = []
for row in cur.fetchall():
discussion_id = row[0]
discussion = {"id": discussion_id, "messages": []}
cur.execute("SELECT * FROM message WHERE discussion_id=?", (discussion_id,))
for message_row in cur.fetchall():
discussion["messages"].append(
{"sender": message_row[1], "content": message_row[2]}
)
discussions.append(discussion)
return discussions
def remove_discussions(db_path="database.db"):
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute("DELETE FROM message")
cur.execute("DELETE FROM discussion")
conn.commit()
# create database schema
def check_discussion_db(db_path):
print("Checking discussions database...")
with sqlite3.connect(db_path) as conn:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS discussion (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT
)
"""
)
cur.execute(
"""
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sender TEXT NOT NULL,
content TEXT NOT NULL,
discussion_id INTEGER NOT NULL,
FOREIGN KEY (discussion_id) REFERENCES discussion(id)
)
"""
)
conn.commit()
print("Ok")
# ========================================================================================================================

View File

@ -171,11 +171,29 @@ goto :CONTINUE
:CONTINUE
echo.
set /p choice=Do you want to download and install the GPT4All model? [Y/N]
if /i ".choice." equ "Y" (
echo -n "Checking for git..."
if command -v git > /dev/null 2>&1; then
echo "OK"
else
read -p "Git is not installed. Would you like to install Git? [Y/N] " choice
if [ "$choice" = "Y" ] || [ "$choice" = "y" ]; then
echo "Installing Git..."
sudo apt update
sudo apt install -y git
else
echo "Please install Git and try again."
exit 1
fi
fi
echo Converting the model to the new format
if not exist tmp/llama.cpp git clone https://github.com/ggerganov/llama.cpp.git tmp\llama.cpp
move models\gpt4all-lora-quantized-ggml.bin models\gpt4all-lora-quantized-ggml.bin.original
python tmp\llama.cpp\migrate-ggml-2023-03-30-pr613.py models\gpt4all-lora-quantized-ggml.bin.original models\gpt4all-lora-quantized-ggml.bin
echo The model file (gpt4all-lora-quantized-ggml.bin) has been fixed.
)
echo Cleaning tmp folder

View File

@ -1,3 +1,36 @@
function populate_models(){
// Get a reference to the <select> element
const selectElement = document.getElementById('model');
// Fetch the list of .bin files from the models subfolder
fetch('/list_models')
.then(response => response.json())
.then(data => {
if (Array.isArray(data)) {
// data is an array
const selectElement = document.getElementById('model');
data.forEach(filename => {
const optionElement = document.createElement('option');
optionElement.value = filename;
optionElement.textContent = filename;
selectElement.appendChild(optionElement);
});
// fetch('/get_args')
// .then(response=> response.json())
// .then(data=>{
// })
} else {
console.error('Expected an array, but received:', data);
}
});
}
populate_models()
const submitButton = document.getElementById('submit-model-params');
submitButton.addEventListener('click', (event) => {
@ -5,7 +38,7 @@ submitButton.addEventListener('click', (event) => {
event.preventDefault();
modelInput = document.getElementById('model');
seedInput = document.getElementById('seed');
tempInput = document.getElementById('temp');
nPredictInput = document.getElementById('n-predict');

View File

@ -28,7 +28,8 @@
<form id="model-params-form" class="bg-white shadow-md rounded px-8 pt-6 pb-8 mb-4">
<div class="mb-4">
<label class="block text-gray-700 font-bold mb-2" for="model">Model</label>
<input class="bg-gray-700 shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline" id="model" type="text" name="model" value="gpt4all-lora-quantized.bin">
<select class="bg-gray-700 shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline" id="model" name="model" value="gpt4all-lora-quantized.bin">
</select>
</div>
<div class="mb-4">
<label class="block text-gray-700 font-bold mb-2" for="seed">Seed</label>