2023-04-06 19:12:49 +00:00
import argparse
2023-04-07 16:58:42 +00:00
import json
2023-04-06 19:12:49 +00:00
import re
import sqlite3
2023-04-07 16:58:42 +00:00
import traceback
2023-04-06 19:12:49 +00:00
from datetime import datetime
2023-04-07 16:58:42 +00:00
from flask import (
Flask ,
Response ,
jsonify ,
render_template ,
request ,
stream_with_context ,
)
from pyllamacpp . model import Model
2023-04-06 19:12:49 +00:00
2023-04-07 16:58:42 +00:00
# =================================== Database ==================================================================
2023-04-06 19:12:49 +00:00
class Discussion:
    """Handle on one discussion (and its messages) stored in a SQLite file.

    An instance only holds the discussion id and the database path; every
    method opens a short-lived connection, runs its statements and closes the
    connection again.  The ``discussion`` and ``message`` tables are expected
    to exist already.
    """

    def __init__(self, discussion_id, db_path="database.db"):
        """Bind this object to ``discussion_id`` inside the database ``db_path``."""
        self.discussion_id = discussion_id
        self.db_path = db_path

    @staticmethod
    def _run(db_path, statements):
        """Execute ``statements`` in one transaction and close the connection.

        Args:
            db_path: Path of the SQLite database file.
            statements: Iterable of ``(sql, params)`` pairs, run in order.

        Returns:
            A ``(rows, lastrowid)`` tuple taken from the last statement run.
        """
        conn = sqlite3.connect(db_path)
        try:
            rows, last_id = [], None
            # ``with conn`` commits on success and rolls back on error, but it
            # does NOT close the connection; the ``finally`` below does.
            with conn:
                cur = conn.cursor()
                for sql, params in statements:
                    cur.execute(sql, params)
                    rows, last_id = cur.fetchall(), cur.lastrowid
            return rows, last_id
        finally:
            conn.close()

    @staticmethod
    def create_discussion(db_path="database.db", title="untitled"):
        """Insert a new discussion row and return a ``Discussion`` bound to it."""
        _, discussion_id = Discussion._run(
            db_path, [("INSERT INTO discussion (title) VALUES (?)", (title,))]
        )
        return Discussion(discussion_id, db_path)

    @staticmethod
    def get_discussion(db_path="database.db", discussion_id=0):
        """Return a ``Discussion`` for ``discussion_id`` without touching the database."""
        return Discussion(discussion_id, db_path)

    def add_message(self, sender, content):
        """Append a message to this discussion and return the new message id."""
        _, message_id = self._run(
            self.db_path,
            [
                (
                    "INSERT INTO message (sender, content, discussion_id) VALUES (?, ?, ?)",
                    (sender, content, self.discussion_id),
                )
            ],
        )
        return message_id

    @staticmethod
    def get_discussions(db_path):
        """Return every discussion as a list of ``{"id", "title"}`` dicts."""
        rows, _ = Discussion._run(db_path, [("SELECT id, title FROM discussion", ())])
        return [{"id": row[0], "title": row[1]} for row in rows]

    @staticmethod
    def rename(db_path, discussion_id, title):
        """Set the title of the discussion identified by ``discussion_id``."""
        Discussion._run(
            db_path,
            [("UPDATE discussion SET title=? WHERE id=?", (title, discussion_id))],
        )

    def delete_discussion(self):
        """Delete this discussion together with all of its messages."""
        key = (self.discussion_id,)
        # Both deletes share one transaction so a failure leaves no half-deleted state.
        self._run(
            self.db_path,
            [
                ("DELETE FROM message WHERE discussion_id=?", key),
                ("DELETE FROM discussion WHERE id=?", key),
            ],
        )

    def get_messages(self):
        """Return this discussion's messages as ``{"sender", "content", "id"}`` dicts."""
        rows, _ = self._run(
            self.db_path,
            [
                (
                    "SELECT id, sender, content FROM message WHERE discussion_id=?",
                    (self.discussion_id,),
                )
            ],
        )
        return [{"sender": row[1], "content": row[2], "id": row[0]} for row in rows]

    def update_message(self, message_id, new_content):
        """Replace the content of the message identified by ``message_id``."""
        self._run(
            self.db_path,
            [("UPDATE message SET content = ? WHERE id = ?", (new_content, message_id))],
        )

    def remove_discussion(self):
        """Delete only the discussion row.

        NOTE(review): unlike ``delete_discussion`` this leaves the discussion's
        messages in the ``message`` table -- confirm whether that is intended.
        """
        self._run(
            self.db_path,
            [("DELETE FROM discussion WHERE id=?", (self.discussion_id,))],
        )
2023-04-07 16:58:42 +00:00
def last_discussion_has_messages(db_path="database.db"):
    """Return True when the ``message`` table contains at least one row.

    NOTE(review): despite the name, the query looks at the most recent message
    of *any* discussion, not at the messages of the most recent discussion.

    Args:
        db_path: Path of the SQLite database file.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM message ORDER BY id DESC LIMIT 1")
        return cursor.fetchone() is not None
    finally:
        # The sqlite3 context manager never closes the connection; do it here.
        conn.close()
2023-04-07 16:58:42 +00:00
def export_to_json(db_path="database.db"):
    """Return all discussions and their messages as JSON-serialisable data.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        A list of ``{"id": <discussion id>, "messages": [...]}`` dicts where
        each message is ``{"sender": ..., "content": ...}``.
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT * FROM discussion")
        discussions = []
        # fetchall() first: the same cursor is reused for the per-discussion query.
        for row in cur.fetchall():
            discussion_id = row[0]
            cur.execute(
                "SELECT * FROM message WHERE discussion_id=?", (discussion_id,)
            )
            messages = [
                {"sender": message_row[1], "content": message_row[2]}
                for message_row in cur.fetchall()
            ]
            discussions.append({"id": discussion_id, "messages": messages})
        return discussions
    finally:
        # The sqlite3 context manager never closes the connection; do it here.
        conn.close()
2023-04-07 16:58:42 +00:00
def remove_discussions(db_path="database.db"):
    """Delete every message and every discussion from the database.

    Args:
        db_path: Path of the SQLite database file.
    """
    conn = sqlite3.connect(db_path)
    try:
        # ``with conn`` commits both deletes together (or rolls both back).
        with conn:
            cur = conn.cursor()
            cur.execute("DELETE FROM message")
            cur.execute("DELETE FROM discussion")
    finally:
        # The sqlite3 context manager never closes the connection; do it here.
        conn.close()
2023-04-07 16:58:42 +00:00
2023-04-06 19:12:49 +00:00
# create database schema
def check_discussion_db(db_path):
    """Create the ``discussion`` and ``message`` tables if they do not exist.

    Progress is reported on stdout.

    Args:
        db_path: Path of the SQLite database file (created by SQLite if missing).
    """
    print("Checking discussions database...")
    conn = sqlite3.connect(db_path)
    try:
        with conn:
            cur = conn.cursor()
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS discussion (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT
                )
                """
            )
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS message (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sender TEXT NOT NULL,
                    content TEXT NOT NULL,
                    discussion_id INTEGER NOT NULL,
                    FOREIGN KEY (discussion_id) REFERENCES discussion(id)
                )
                """
            )
    finally:
        # The sqlite3 context manager never closes the connection; do it here.
        conn.close()
    print("Ok")
2023-04-07 16:58:42 +00:00
2023-04-06 19:12:49 +00:00
# ========================================================================================================================
2023-04-07 16:58:42 +00:00
# Module-level Flask application; files in ./static are served under /static.
app = Flask(
    "GPT4All-WebUI",
    static_url_path="/static",
    static_folder="static",
)
2023-04-06 19:12:49 +00:00
2023-04-07 16:58:42 +00:00
class Gpt4AllWebUI:
    """Flask front-end that stores chats in SQLite and relays them to a model.

    The constructor registers every HTTP endpoint on the given Flask app and
    then sends a conditioning prompt to the model bindings.
    """

    # Generation settings passed to every ``generate`` call.
    N_PREDICT = 55
    N_THREADS = 8

    def __init__(self, chatbot_bindings, _app, db_path="database.db") -> None:
        """Register the routes on ``_app`` and prime the model.

        Args:
            chatbot_bindings: Object exposing ``generate(prompt, n_predict=...,
                new_text_callback=..., n_threads=...)``.
            _app: Flask application on which the URL rules are added.
            db_path: Path of the SQLite database holding the discussions.
        """
        self.current_discussion = None
        self.chatbot_bindings = chatbot_bindings
        self.app = _app
        self.db_path = db_path

        self.add_endpoint("/", "", self.index, methods=["GET"])
        self.add_endpoint("/export", "export", self.export, methods=["GET"])
        self.add_endpoint(
            "/new_discussion", "new_discussion", self.new_discussion, methods=["GET"]
        )
        self.add_endpoint("/bot", "bot", self.bot, methods=["POST"])
        self.add_endpoint(
            "/discussions", "discussions", self.discussions, methods=["GET"]
        )
        self.add_endpoint("/rename", "rename", self.rename, methods=["POST"])
        self.add_endpoint(
            "/get_messages", "get_messages", self.get_messages, methods=["POST"]
        )
        self.add_endpoint(
            "/delete_discussion",
            "delete_discussion",
            self.delete_discussion,
            methods=["POST"],
        )
        self.add_endpoint(
            "/update_message", "update_message", self.update_message, methods=["GET"]
        )

        # Chatbot conditioning: have the model greet the user once at start-up.
        conditionning_message = """
Instruction: Act as GPT4All. A kind and helpful AI bot built to help users solve problems.
Start by welcoming the user then stop sending text.
GPT4All:"""
        self.prepare_query(conditionning_message)
        chatbot_bindings.generate(
            conditionning_message,
            n_predict=self.N_PREDICT,
            new_text_callback=self.new_text_callback,
            n_threads=self.N_THREADS,
        )
        print(f"Bot said:{self.bot_says}")

    def prepare_query(self, message):
        """Reset the generation buffers before sending ``message`` to the model."""
        self.bot_says = ""
        self.full_text = ""
        self.is_bot_text_started = False
        self.current_message = message

    def _accumulate(self, text):
        """Echo ``text`` and append it to the generation buffers.

        Text is added to ``bot_says`` only once the prompt (``current_message``)
        has been seen in the accumulated output, so the echoed prompt itself is
        not part of the answer.
        """
        print(text, end="")
        self.full_text += text
        if self.is_bot_text_started:
            self.bot_says += text
        if self.current_message in self.full_text:
            self.is_bot_text_started = True

    def new_text_callback(self, text: str):
        """Callback handed to the model bindings for each generated chunk."""
        self._accumulate(text)

    def new_text_callback_with_yield(self, text: str):
        """Generator variant of ``new_text_callback``.

        To do: fix the problem with yield to be able to show the interactive
        response as text comes.  Being a generator function, nothing runs
        until the returned generator is iterated.
        """
        self._accumulate(text)
        yield text

    def add_endpoint(
        self,
        endpoint=None,
        endpoint_name=None,
        handler=None,
        methods=None,
        *args,
        **kwargs,
    ):
        """Register ``handler`` for ``endpoint`` on the Flask app.

        Args:
            endpoint: URL rule, e.g. ``"/bot"``.
            endpoint_name: Flask endpoint name.
            handler: View callable.
            methods: Allowed HTTP methods; defaults to ``["GET"]``.
            *args, **kwargs: Forwarded to ``app.add_url_rule``.
        """
        # ``None`` sentinel avoids sharing one mutable default list across calls.
        if methods is None:
            methods = ["GET"]
        self.app.add_url_rule(
            endpoint, endpoint_name, handler, *args, methods=methods, **kwargs
        )

    @staticmethod
    def _exception_page(ex):
        """Print ``ex`` and build the HTML error snippet returned to the client.

        Must be called from inside an ``except`` block so that
        ``traceback.format_exc()`` sees the active exception.
        """
        print(ex)
        return (
            "<b style='color:red;'>Exception :<b>"
            + str(ex)
            + "<br>"
            + traceback.format_exc()
            + "<br>Please report exception"
        )

    def index(self):
        """Serve the chat page."""
        return render_template("chat.html")

    def format_message(self, message):
        """Replace every ```-fenced code block in ``message`` with a <code> tag."""
        pattern = re.compile(r"```(.*?)```", re.DOTALL)
        # A callable replacement keeps backslashes in the code block literal.
        return pattern.sub(lambda found: f"<code>{found.group(1)}</code>", message)

    def export(self):
        """Return all discussions and messages as a JSON response."""
        return jsonify(export_to_json(self.db_path))

    def parse_to_prompt_stream(self, message, message_id):
        """Generate the bot answer for ``message`` as a two-chunk stream.

        Yields first a JSON document describing the exchange (user message,
        its id and the id of the stored bot response), then the complete bot
        answer.  The caller is responsible for wrapping the generator with
        ``stream_with_context``.
        """
        self.stop = False

        # send the message to the bot
        print(f"Received message : {message}")
        # First we need to send the new message ID to the client.
        # The content is empty at first; it is filled in once generation ends.
        response_id = self.current_discussion.add_message("GPT4All", "")
        yield (
            json.dumps(
                {
                    "type": "input_message_infos",
                    "message": message,
                    "id": message_id,
                    "response_id": response_id,
                }
            )
        )

        self.current_message = "User: " + message + "\nGPT4All:"
        self.prepare_query(self.current_message)
        # Use the bindings given to the constructor rather than a module global.
        self.chatbot_bindings.generate(
            self.current_message,
            n_predict=self.N_PREDICT,
            new_text_callback=self.new_text_callback,
            n_threads=self.N_THREADS,
        )

        self.current_discussion.update_message(response_id, self.bot_says)
        # TODO: use the yield-based callback to send the text word by word.
        yield self.bot_says

    def bot(self):
        """Store the posted user message and stream back the bot answer."""
        self.stop = True
        try:
            if self.current_discussion is None or not last_discussion_has_messages(
                self.db_path
            ):
                self.current_discussion = Discussion.create_discussion(self.db_path)

            message_id = self.current_discussion.add_message(
                "user", request.json["message"]
            )
            message = f"{request.json['message']}"

            # Segmented (the user receives the output as it comes): first a
            # JSON entry with the message ids, then the text.
            return Response(
                stream_with_context(
                    self.parse_to_prompt_stream(message, message_id)
                )
            )
        except Exception as ex:
            return self._exception_page(ex)

    def discussions(self):
        """Return the list of discussions as JSON."""
        try:
            return jsonify(Discussion.get_discussions(self.db_path))
        except Exception as ex:
            return self._exception_page(ex)

    def rename(self):
        """Rename the discussion given by the posted ``id`` to ``title``."""
        data = request.get_json()
        Discussion.rename(self.db_path, data["id"], data["title"])
        return "renamed successfully"

    def get_messages(self):
        """Make the posted discussion current and return its messages as JSON."""
        data = request.get_json()
        self.current_discussion = Discussion(data["id"], self.db_path)
        return jsonify(self.current_discussion.get_messages())

    def delete_discussion(self):
        """Delete the posted discussion and clear the current discussion."""
        data = request.get_json()
        self.current_discussion = Discussion(data["id"], self.db_path)
        self.current_discussion.delete_discussion()
        self.current_discussion = None
        return jsonify({})

    def update_message(self):
        """Overwrite the message ``id`` with the ``message`` query argument."""
        try:
            message_id = request.args.get("id")
            new_message = request.args.get("message")
            self.current_discussion.update_message(message_id, new_message)
            return jsonify({"status": "ok"})
        except Exception as ex:
            return self._exception_page(ex)

    def new_discussion(self):
        """Create a discussion titled by the ``title`` query argument.

        Returns a JSON string with the new discussion id and a timestamp.
        """
        title = request.args.get("title")
        self.current_discussion = Discussion.create_discussion(self.db_path, title)
        # Get the current timestamp
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Return a success response
        return json.dumps({"id": self.current_discussion.discussion_id, "time": timestamp})
if __name__ == "__main__":
    # Command-line options, registered in the order they appear in --help.
    cli_options = (
        (
            "--temp",
            {"type": float, "default": 0.1, "help": "Temperature parameter for the model."},
        ),
        (
            "--n_predict",
            {"type": int, "default": 128, "help": "Number of tokens to predict at each step."},
        ),
        (
            "--top_k",
            {"type": int, "default": 40, "help": "Value for the top-k sampling."},
        ),
        (
            "--top_p",
            {"type": float, "default": 0.95, "help": "Value for the top-p sampling."},
        ),
        (
            "--repeat_penalty",
            {"type": float, "default": 1.3, "help": "Penalty for repeated tokens."},
        ),
        (
            "--repeat_last_n",
            {
                "type": int,
                "default": 64,
                "help": "Number of previous tokens to consider for the repeat penalty.",
            },
        ),
        (
            "--ctx_size",
            {"type": int, "default": 2048, "help": "Size of the context window for the model."},
        ),
        (
            "--debug",
            {"dest": "debug", "action": "store_true", "help": "launch Flask server in debug mode"},
        ),
        (
            "--host",
            {"type": str, "default": "localhost", "help": "the hostname to listen on"},
        ),
        (
            "--port",
            {"type": int, "default": 9600, "help": "the port to listen on"},
        ),
        (
            "--db_path",
            {"type": str, "default": "database.db", "help": "Database path"},
        ),
    )

    parser = argparse.ArgumentParser(description="Start the chatbot Flask app.")
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    parser.set_defaults(debug=False)
    args = parser.parse_args()

    # The model is loaded with a fixed path and context size.  The sampling
    # options parsed above (temp, n_predict, top_k, top_p, repeat_penalty,
    # repeat_last_n, ctx_size) were fed to the former GPT4All(decoder_config=...)
    # wrapper and are not forwarded to the model here.
    chatbot_model_bindings = Model(
        ggml_model="./models/gpt4all-lora-quantized-ggml.bin", n_ctx=512
    )

    check_discussion_db(args.db_path)
    bot = Gpt4AllWebUI(chatbot_model_bindings, app, args.db_path)

    # Only pass ``debug`` when it was requested, leaving Flask's default otherwise.
    run_options = {"host": args.host, "port": args.port}
    if args.debug:
        run_options["debug"] = True
    app.run(**run_options)