Faster generation, can stop generation while it is generating

This commit is contained in:
ParisNeo 2023-04-30 03:15:11 +02:00
parent f98f115e6a
commit 5571cd0057
6 changed files with 73 additions and 65 deletions

40
app.py
View File

@ -35,8 +35,11 @@ from flask import (
from flask_socketio import SocketIO, emit
from pathlib import Path
import gc
from geventwebsocket.handler import WebSocketHandler
from gevent.pywsgi import WSGIServer
app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
socketio = SocketIO(app)
socketio = SocketIO(app, async_mode='gevent')
app.config['SECRET_KEY'] = 'secret!'
# Set the logging level to WARNING or higher
logging.getLogger('socketio').setLevel(logging.WARNING)
@ -314,15 +317,25 @@ class Gpt4AllWebUI(GPT4AllAPI):
self.generating = True
# app.config['executor'] = ThreadPoolExecutor(max_workers=1)
# app.config['executor'].submit(self.generate_message)
print("## Generate message ##")
print("## Generating message ##")
self.generate_message()
print()
print("## Done ##")
print()
# Send final message
self.socketio.emit('final', {'data': self.bot_says})
self.current_discussion.update_message(response_id, self.bot_says)
self.full_message_list.append(self.bot_says)
self.cancel_gen = False
return bot_says
else:
#No discussion available
print()
print("## Done ##")
print()
return ""
@ -596,7 +609,26 @@ if __name__ == "__main__":
# app.config['executor'] = executor
bot = Gpt4AllWebUI(app, socketio, config, personality, config_file_path)
# chong Define custom WebSocketHandler with error handling
class CustomWebSocketHandler(WebSocketHandler):
def handle_error(self, environ, start_response, e):
# Handle the error here
print("WebSocket error:", e)
super().handle_error(environ, start_response, e)
# chong -add socket server
http_server = WSGIServer((config["host"], config["port"]), app, handler_class=CustomWebSocketHandler)
http_server = WSGIServer((config["host"], config["port"]), app, handler_class=WebSocketHandler)
if config["debug"]:
app.run(debug=True, host=config["host"], port=config["port"])
socketio.run(app,debug=True, host=config["host"], port=config["port"])
else:
app.run(host=config["host"], port=config["port"])
socketio.run(app, host=config["host"], port=config["port"])
# if config["debug"]:
# app.run(debug=True, host=config["host"], port=config["port"])
# else:
# app.run(host=config["host"], port=config["port"])

View File

@ -31,6 +31,7 @@ class LLAMACPP(GPTBackend):
self.model = Model(
ggml_model=f"./models/llama_cpp/{self.config['model']}",
prompt_context="", prompt_prefix="", prompt_suffix="", anti_prompts= [],
n_ctx=self.config['ctx_size'],
seed=self.config['seed'],
)
@ -53,9 +54,8 @@ class LLAMACPP(GPTBackend):
verbose (bool, optional): If true, the code will spit many informations about the generation process. Defaults to False.
"""
try:
self.model.generate(
prompt,
new_text_callback=new_text_callback,
self.model.reset()
for tok in self.model.generate(prompt,
n_predict=n_predict,
temp=self.config['temp'],
top_k=self.config['top_k'],
@ -63,7 +63,8 @@ class LLAMACPP(GPTBackend):
repeat_penalty=self.config['repeat_penalty'],
repeat_last_n = self.config['repeat_last_n'],
n_threads=self.config['n_threads'],
verbose=verbose
)
):
if not new_text_callback(tok):
return
except Exception as ex:
print(ex)

View File

@ -30,13 +30,13 @@ personality_conditionning: |
welcome_message: "Welcome! I am GPT4All A free and open assistant. What can I do for you today?"
# This prefix is added at the beginning of any message input by the user
user_message_prefix: "### Human:"
user_message_prefix: "### Human:\n"
# A text to put between user and chatbot messages
link_text: "\n"
# This prefix is added at the beginning of any message output by the ai
ai_message_prefix: "### Assistant:"
ai_message_prefix: "### Assistant:\n"
# Here is the list of extensions this personality requires
dependencies: []

View File

@ -196,9 +196,8 @@ class GPT4AllAPI():
print(text, end="")
sys.stdout.flush()
if self.chatbot_bindings.inline:
self.bot_says += text
if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
if not self.personality["user_message_prefix"].strip().lower() in self.bot_says.lower():
self.socketio.emit('message', {'data': self.bot_says});
if self.cancel_gen:
print("Generation canceled")
@ -206,35 +205,9 @@ class GPT4AllAPI():
else:
return True
else:
self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].lower())
self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].strip())
print("The model is halucinating")
self.socketio.emit('final', {'data': self.bot_says})
return False
else:
self.full_text += text
if self.is_bot_text_started:
self.bot_says += text
if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
self.socketio.emit('message', {'data': self.bot_says});
#self.socketio.emit('message', {'data': text});
if self.cancel_gen:
print("Generation canceled")
self.socketio.emit('final', {'data': self.bot_says})
return False
else:
return True
else:
self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].lower())
print("The model is halucinating")
self.socketio.emit('final', {'data': self.bot_says})
self.cancel_gen=True
return False
#if self.current_message in self.full_text:
if len(self.discussion_messages) < len(self.full_text):
self.is_bot_text_started = True
else:
self.socketio.emit('waiter', {'wait': (len(self.discussion_messages)-len(self.full_text))/len(self.discussion_messages)});
def generate_message(self):
self.generating=True

View File

@ -4,7 +4,7 @@ nomic
pytest
pyyaml
markdown
pyllamacpp==1.0.7
pyllamacpp==2.0.0
gpt4all-j==0.2.1
--find-links https://download.pytorch.org/whl/cu117
torch==2.0.0
@ -12,3 +12,5 @@ torchvision
torchaudio
transformers
accelerate
gevent
gevent-websocket

View File

@ -13,9 +13,6 @@ function send_message(service_name, parameters){
globals.socket = socket
globals.is_generating = false
socket.on('connect', function() {
globals.sendbtn.style.display="block";
globals.waitAnimation.style.display="none";
globals.stopGeneration.style.display = "none";
entry_counter = 0;
if(globals.is_generating){
globals.socket.disconnect()
@ -29,6 +26,10 @@ function send_message(service_name, parameters){
socket.on('disconnect', function() {
console.log("disconnected")
entry_counter = 0;
console.log("Disconnected")
globals.sendbtn.style.display="block";
globals.waitAnimation.style.display="none";
globals.stopGeneration.style.display = "none";
});
@ -40,10 +41,11 @@ function send_message(service_name, parameters){
}
globals.bot_msg.setSender(msg.bot);
globals.bot_msg.setID(msg.response_id);
globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please satnd by...`;
});
socket.on('waiter', function(msg) {
globals.bot_msg.messageTextElement.innerHTML = `Remaining words ${Math.floor(msg.wait * 100)}%`;
globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please satnd by...`;
});
socket.on('message', function(msg) {
@ -59,10 +61,8 @@ function send_message(service_name, parameters){
text = msg.data;
globals.bot_msg.hiddenElement.innerHTML = text
globals.bot_msg.messageTextElement.innerHTML = marked.marked(text)
socket.disconnect()
globals.sendbtn.style.display="block";
globals.waitAnimation.style.display="none";
globals.stopGeneration.style.display = "none";
});
}