Faster generation; generation can now be stopped while it is in progress

ParisNeo 2023-04-30 03:15:11 +02:00
parent f98f115e6a
commit 5571cd0057
6 changed files with 73 additions and 65 deletions

app.py
View File

@@ -35,8 +35,11 @@ from flask import (
 from flask_socketio import SocketIO, emit
 from pathlib import Path
 import gc
+from geventwebsocket.handler import WebSocketHandler
+from gevent.pywsgi import WSGIServer
 app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
-socketio = SocketIO(app)
+socketio = SocketIO(app, async_mode='gevent')
 app.config['SECRET_KEY'] = 'secret!'
 # Set the logging level to WARNING or higher
 logging.getLogger('socketio').setLevel(logging.WARNING)
@@ -314,15 +317,25 @@ class Gpt4AllWebUI(GPT4AllAPI):
             self.generating = True
             # app.config['executor'] = ThreadPoolExecutor(max_workers=1)
             # app.config['executor'].submit(self.generate_message)
-            print("## Generate message ##")
+            print("## Generating message ##")
             self.generate_message()
+            print()
             print("## Done ##")
+            print()
+            # Send final message
+            self.socketio.emit('final', {'data': self.bot_says})
             self.current_discussion.update_message(response_id, self.bot_says)
             self.full_message_list.append(self.bot_says)
             self.cancel_gen = False
             return bot_says
         else:
+            # No discussion available
+            print()
             print("## Done ##")
+            print()
             return ""
@@ -596,7 +609,26 @@ if __name__ == "__main__":
     # app.config['executor'] = executor
     bot = Gpt4AllWebUI(app, socketio, config, personality, config_file_path)
+    # chong Define custom WebSocketHandler with error handling
+    class CustomWebSocketHandler(WebSocketHandler):
+        def handle_error(self, environ, start_response, e):
+            # Handle the error here
+            print("WebSocket error:", e)
+            super().handle_error(environ, start_response, e)
+    # chong -add socket server
+    http_server = WSGIServer((config["host"], config["port"]), app, handler_class=CustomWebSocketHandler)
+    http_server = WSGIServer((config["host"], config["port"]), app, handler_class=WebSocketHandler)
     if config["debug"]:
-        app.run(debug=True, host=config["host"], port=config["port"])
+        socketio.run(app, debug=True, host=config["host"], port=config["port"])
     else:
-        app.run(host=config["host"], port=config["port"])
+        socketio.run(app, host=config["host"], port=config["port"])
+    # if config["debug"]:
+    #     app.run(debug=True, host=config["host"], port=config["port"])
+    # else:
+    #     app.run(host=config["host"], port=config["port"])
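Why this hunk matters: app.run() is Flask's blocking development server, while socketio.run() on the gevent stack keeps servicing socket events (including a stop request) while tokens stream out. Below is a minimal, self-contained sketch of the same pattern; the 'generate'/'cancel' event names, the port, and fake_token_stream() are illustrative stand-ins, not the app's actual code:

    from flask import Flask
    from flask_socketio import SocketIO, emit

    app = Flask(__name__)
    socketio = SocketIO(app, async_mode='gevent')

    cancel_requested = False

    def fake_token_stream(prompt):
        # Stand-in for a model emitting tokens one by one.
        yield from prompt.split()

    @socketio.on('generate')
    def on_generate(data):
        global cancel_requested
        cancel_requested = False
        for chunk in fake_token_stream(data['prompt']):
            if cancel_requested:
                break
            emit('message', {'data': chunk})
            socketio.sleep(0)  # yield so a pending 'cancel' event can be handled

    @socketio.on('cancel')
    def on_cancel():
        global cancel_requested
        cancel_requested = True

    if __name__ == '__main__':
        socketio.run(app, host='0.0.0.0', port=9600)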

View File

@@ -30,7 +30,8 @@ class LLAMACPP(GPTBackend):
         super().__init__(config, False)
         self.model = Model(
             ggml_model=f"./models/llama_cpp/{self.config['model']}",
+            prompt_context="", prompt_prefix="", prompt_suffix="", anti_prompts=[],
             n_ctx=self.config['ctx_size'],
             seed=self.config['seed'],
         )
@@ -53,17 +54,17 @@ class LLAMACPP(GPTBackend):
             verbose (bool, optional): If true, the code will spit many informations about the generation process. Defaults to False.
         """
         try:
-            self.model.generate(
-                prompt,
-                new_text_callback=new_text_callback,
-                n_predict=n_predict,
-                temp=self.config['temp'],
-                top_k=self.config['top_k'],
-                top_p=self.config['top_p'],
-                repeat_penalty=self.config['repeat_penalty'],
-                repeat_last_n=self.config['repeat_last_n'],
-                n_threads=self.config['n_threads'],
-                verbose=verbose
-            )
+            self.model.reset()
+            for tok in self.model.generate(prompt,
+                    n_predict=n_predict,
+                    temp=self.config['temp'],
+                    top_k=self.config['top_k'],
+                    top_p=self.config['top_p'],
+                    repeat_penalty=self.config['repeat_penalty'],
+                    repeat_last_n=self.config['repeat_last_n'],
+                    n_threads=self.config['n_threads'],
+                    ):
+                if not new_text_callback(tok):
+                    return
         except Exception as ex:
             print(ex)
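This hunk flips who is in control: instead of handing pyllamacpp a callback and relying on the library to honor it, the backend pulls tokens from generate() as a generator and simply exits the loop when the callback returns False. A standalone sketch of that contract, assuming nothing about pyllamacpp beyond what the hunk shows (token_source and demo_callback are stand-ins):

    from typing import Callable, Iterable

    def stream_tokens(token_source: Iterable[str],
                      new_text_callback: Callable[[str], bool]) -> None:
        for tok in token_source:
            if not new_text_callback(tok):
                # Callback vetoed further generation (user pressed stop,
                # or the model started hallucinating a user turn).
                return

    def demo_callback(tok: str) -> bool:
        print(tok, end="", flush=True)
        return tok != "<stop>"   # illustrative stop condition

    stream_tokens(iter(["Hello", " ", "world", "<stop>", "ignored"]), demo_callback)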

View File

@@ -30,13 +30,13 @@ personality_conditionning: |
 welcome_message: "Welcome! I am GPT4All A free and open assistant. What can I do for you today?"
 # This prefix is added at the beginning of any message input by the user
-user_message_prefix: "### Human:"
+user_message_prefix: "### Human:\n"
 # A text to put between user and chatbot messages
 link_text: "\n"
 # This prefix is added at the beginning of any message output by the ai
-ai_message_prefix: "### Assistant:"
+ai_message_prefix: "### Assistant:\n"
 # Here is the list of extensions this personality requires
 dependencies: []
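The trailing \n added to both prefixes matters because these strings are concatenated directly against message text when the prompt is assembled; without it, "### Human:" and the user's words land on one line. An illustrative reconstruction of that assembly (build_prompt and the history shape are hypothetical, not the app's actual code):

    personality = {
        "user_message_prefix": "### Human:\n",
        "link_text": "\n",
        "ai_message_prefix": "### Assistant:\n",
    }

    def build_prompt(history):
        # history: list of (role, message) tuples, role in {"user", "ai"}
        parts = [personality[f"{role}_message_prefix"] + msg for role, msg in history]
        # Close with the assistant prefix so the model answers on a fresh line.
        parts.append(personality["ai_message_prefix"])
        return personality["link_text"].join(parts)

    print(build_prompt([("user", "What is gevent?")]))
    # ### Human:
    # What is gevent?
    # ### Assistant: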

View File

@@ -196,45 +196,18 @@ class GPT4AllAPI():
         print(text, end="")
         sys.stdout.flush()
-        if self.chatbot_bindings.inline:
-            self.bot_says += text
-            if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
-                self.socketio.emit('message', {'data': self.bot_says});
-                if self.cancel_gen:
-                    print("Generation canceled")
-                    return False
-                else:
-                    return True
-            else:
-                self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].lower())
-                print("The model is halucinating")
-                self.socketio.emit('final', {'data': self.bot_says})
-                return False
-        else:
-            self.full_text += text
-            if self.is_bot_text_started:
-                self.bot_says += text
-                if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
-                    self.socketio.emit('message', {'data': self.bot_says});
-                    #self.socketio.emit('message', {'data': text});
-                    if self.cancel_gen:
-                        print("Generation canceled")
-                        self.socketio.emit('final', {'data': self.bot_says})
-                        return False
-                    else:
-                        return True
-                else:
-                    self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].lower())
-                    print("The model is halucinating")
-                    self.socketio.emit('final', {'data': self.bot_says})
-                    self.cancel_gen=True
-                    return False
-            #if self.current_message in self.full_text:
-            if len(self.discussion_messages) < len(self.full_text):
-                self.is_bot_text_started = True
-            else:
-                self.socketio.emit('waiter', {'wait': (len(self.discussion_messages)-len(self.full_text))/len(self.discussion_messages)});
+        self.bot_says += text
+        if not self.personality["user_message_prefix"].strip().lower() in self.bot_says.lower():
+            self.socketio.emit('message', {'data': self.bot_says});
+            if self.cancel_gen:
+                print("Generation canceled")
+                return False
+            else:
+                return True
+        else:
+            self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].strip())
+            print("The model is halucinating")
+            return False

     def generate_message(self):
         self.generating=True
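The 45-line callback collapses to 18 because the old inline/non-inline split is gone: with the new streaming loop the backend no longer echoes the prompt back, so the "wait for the bot's text to start" bookkeeping (full_text, is_bot_text_started, the 'waiter' event) becomes unnecessary. A standalone sketch of the surviving logic, with emit() and remove_text_from_string() as stand-ins for the real methods:

    class CallbackSketch:
        def __init__(self, user_prefix="### Human:\n"):
            self.user_prefix = user_prefix
            self.bot_says = ""
            self.cancel_gen = False

        def emit(self, event, payload):
            # Stand-in for self.socketio.emit(event, payload)
            print(f"[{event}] {payload['data']!r}")

        def remove_text_from_string(self, text, marker):
            # Keep everything before the first (case-insensitive) marker hit.
            return text[:text.lower().find(marker.lower())]

        def new_text_callback(self, text: str) -> bool:
            self.bot_says += text
            marker = self.user_prefix.strip()
            if marker.lower() not in self.bot_says.lower():
                self.emit('message', {'data': self.bot_says})
                if self.cancel_gen:
                    print("Generation canceled")
                    return False
                return True
            # The model began writing the user's next turn: trim it and stop.
            self.bot_says = self.remove_text_from_string(self.bot_says, marker)
            print("The model is hallucinating")
            return False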

View File

@@ -4,11 +4,13 @@ nomic
 pytest
 pyyaml
 markdown
-pyllamacpp==1.0.7
+pyllamacpp==2.0.0
 gpt4all-j==0.2.1
 --find-links https://download.pytorch.org/whl/cu117
 torch==2.0.0
 torchvision
 torchaudio
 transformers
 accelerate
+gevent
+gevent-websocket
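gevent and gevent-websocket are what async_mode='gevent' in app.py relies on; Flask-SocketIO refuses that mode if they are missing. A quick sanity check after installing (a sketch, not part of the repo):

    from flask import Flask
    from flask_socketio import SocketIO

    # Raises ValueError("Invalid async_mode specified") if gevent is not installed.
    sio = SocketIO(Flask(__name__), async_mode='gevent')
    print(sio.async_mode)  # expected: 'gevent'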

View File

@@ -13,9 +13,6 @@ function send_message(service_name, parameters){
     globals.socket = socket
     globals.is_generating = false
     socket.on('connect', function() {
-        globals.sendbtn.style.display="block";
-        globals.waitAnimation.style.display="none";
-        globals.stopGeneration.style.display = "none";
         entry_counter = 0;
         if(globals.is_generating){
             globals.socket.disconnect()
@@ -29,6 +26,10 @@ function send_message(service_name, parameters){
     socket.on('disconnect', function() {
         console.log("disconnected")
         entry_counter = 0;
+        console.log("Disconnected")
+        globals.sendbtn.style.display="block";
+        globals.waitAnimation.style.display="none";
+        globals.stopGeneration.style.display = "none";
     });
@@ -40,10 +41,11 @@ function send_message(service_name, parameters){
         }
         globals.bot_msg.setSender(msg.bot);
         globals.bot_msg.setID(msg.response_id);
+        globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please satnd by...`;
     });
     socket.on('waiter', function(msg) {
-        globals.bot_msg.messageTextElement.innerHTML = `Remaining words ${Math.floor(msg.wait * 100)}%`;
+        globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please satnd by...`;
     });
     socket.on('message', function(msg) {
@@ -59,10 +61,8 @@ function send_message(service_name, parameters){
         text = msg.data;
         globals.bot_msg.hiddenElement.innerHTML = text
         globals.bot_msg.messageTextElement.innerHTML = marked.marked(text)
-        socket.disconnect()
-        globals.sendbtn.style.display="block";
-        globals.waitAnimation.style.display="none";
-        globals.stopGeneration.style.display = "none";
     });
 }