Mirror of https://github.com/ParisNeo/lollms-webui.git (synced 2024-12-19 04:17:52 +00:00)

Commit 5571cd0057 (parent f98f115e6a)
Faster generation; generation can now be stopped while it is running.

Changed files shown below: app.py (40 changed lines), plus the llama_cpp backend, the default personality configuration, the GPT4All API, requirements.txt, and the frontend chat script.
@@ -35,8 +35,11 @@ from flask import (
 from flask_socketio import SocketIO, emit
 from pathlib import Path
 import gc
+from geventwebsocket.handler import WebSocketHandler
+from gevent.pywsgi import WSGIServer
 
 app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static")
-socketio = SocketIO(app)
+socketio = SocketIO(app, async_mode='gevent')
 app.config['SECRET_KEY'] = 'secret!'
 # Set the logging level to WARNING or higher
 logging.getLogger('socketio').setLevel(logging.WARNING)
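The switch to async_mode='gevent' is what lets the server keep emitting socket events while a long generation runs. A minimal sketch of the same setup, separate from the project's code (the event name and port are illustrative):

    from flask import Flask
    from flask_socketio import SocketIO

    app = Flask(__name__)
    socketio = SocketIO(app, async_mode='gevent')

    @socketio.on('generate_msg')
    def handle_generate(data):
        # Under gevent, emits from a long-running handler cooperate with
        # other greenlets instead of blocking the whole server.
        socketio.emit('message', {'data': 'partial text...'})

    if __name__ == '__main__':
        socketio.run(app, host='0.0.0.0', port=9600)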
@@ -314,15 +317,25 @@ class Gpt4AllWebUI(GPT4AllAPI):
             self.generating = True
             # app.config['executor'] = ThreadPoolExecutor(max_workers=1)
             # app.config['executor'].submit(self.generate_message)
-            print("## Generate message ##")
+            print("## Generating message ##")
             self.generate_message()
+
+            print()
+            print("## Done ##")
+            print()
+
+            # Send final message
+            self.socketio.emit('final', {'data': self.bot_says})
+
             self.current_discussion.update_message(response_id, self.bot_says)
             self.full_message_list.append(self.bot_says)
+            self.cancel_gen = False
             return bot_says
         else:
             #No discussion available
+            print()
+            print("## Done ##")
+            print()
             return ""
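The 'final' event plus the cancel_gen reset implement the stop feature server-side. A sketch of the flag pattern, with field names taken from the diff but the class and method names hypothetical:

    # Sketch only: a cancel flag checked by the per-token callback.
    class GenerationSession:
        def __init__(self, socketio):
            self.socketio = socketio
            self.cancel_gen = False
            self.bot_says = ""

        def new_text_callback(self, text: str) -> bool:
            self.bot_says += text
            self.socketio.emit('message', {'data': self.bot_says})
            # Returning False asks the token loop to stop early.
            return not self.cancel_gen

        def stop_generation(self):
            # Wired to a separate socket event (e.g. the Stop button).
            self.cancel_gen = True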
@@ -596,7 +609,26 @@ if __name__ == "__main__":
     # app.config['executor'] = executor
     bot = Gpt4AllWebUI(app, socketio, config, personality, config_file_path)
 
+    # chong Define custom WebSocketHandler with error handling
+    class CustomWebSocketHandler(WebSocketHandler):
+        def handle_error(self, environ, start_response, e):
+            # Handle the error here
+            print("WebSocket error:", e)
+            super().handle_error(environ, start_response, e)
+
+    # chong -add socket server
+    http_server = WSGIServer((config["host"], config["port"]), app, handler_class=CustomWebSocketHandler)
+    http_server = WSGIServer((config["host"], config["port"]), app, handler_class=WebSocketHandler)
+
     if config["debug"]:
-        app.run(debug=True, host=config["host"], port=config["port"])
+        socketio.run(app, debug=True, host=config["host"], port=config["port"])
     else:
-        app.run(host=config["host"], port=config["port"])
+        socketio.run(app, host=config["host"], port=config["port"])
+
+    # if config["debug"]:
+    #     app.run(debug=True, host=config["host"], port=config["port"])
+    # else:
+    #     app.run(host=config["host"], port=config["port"])
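Two things stand out in this hunk: the second WSGIServer(...) assignment immediately rebinds http_server, so CustomWebSocketHandler is never actually used as committed, and no serve_forever() call appears here, since socketio.run() starts its own gevent server in this mode. For reference, a standalone sketch of serving a WSGI app over gevent with websocket support (host and port are illustrative):

    from gevent.pywsgi import WSGIServer
    from geventwebsocket.handler import WebSocketHandler

    def serve(app, host="0.0.0.0", port=9600):
        # Roughly what flask_socketio does internally in gevent mode.
        server = WSGIServer((host, port), app, handler_class=WebSocketHandler)
        server.serve_forever()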
@@ -30,7 +30,8 @@ class LLAMACPP(GPTBackend):
         super().__init__(config, False)
 
         self.model = Model(
             ggml_model=f"./models/llama_cpp/{self.config['model']}",
+            prompt_context="", prompt_prefix="", prompt_suffix="", anti_prompts=[],
             n_ctx=self.config['ctx_size'],
             seed=self.config['seed'],
         )
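The new constructor arguments neutralize the prompt formatting that pyllamacpp 2.x handles itself (empty prefixes, no anti-prompts), so the web UI keeps full control of the prompt. A hedged usage sketch with the keywords from the diff and an illustrative model path:

    # Sketch: instantiating the pyllamacpp 2.x Model as configured above.
    from pyllamacpp.model import Model

    model = Model(
        ggml_model="./models/llama_cpp/ggml-model-q4_0.bin",  # illustrative
        prompt_context="", prompt_prefix="", prompt_suffix="", anti_prompts=[],
        n_ctx=512,   # context window size
        seed=-1,     # -1 picks a random seed in most ggml bindings
    )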
@@ -53,17 +54,17 @@ class LLAMACPP(GPTBackend):
             verbose (bool, optional): If true, the code prints detailed information about the generation process. Defaults to False.
         """
         try:
-            self.model.generate(
-                prompt,
-                new_text_callback=new_text_callback,
-                n_predict=n_predict,
-                temp=self.config['temp'],
-                top_k=self.config['top_k'],
-                top_p=self.config['top_p'],
-                repeat_penalty=self.config['repeat_penalty'],
-                repeat_last_n=self.config['repeat_last_n'],
-                n_threads=self.config['n_threads'],
-                verbose=verbose
-            )
+            self.model.reset()
+            for tok in self.model.generate(
+                prompt,
+                n_predict=n_predict,
+                temp=self.config['temp'],
+                top_k=self.config['top_k'],
+                top_p=self.config['top_p'],
+                repeat_penalty=self.config['repeat_penalty'],
+                repeat_last_n=self.config['repeat_last_n'],
+                n_threads=self.config['n_threads'],
+            ):
+                if not new_text_callback(tok):
+                    return
         except Exception as ex:
             print(ex)
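This is the core of the commit: pyllamacpp 2.0.0 turns generate() from a callback-driven call into a token generator, so the caller can stop mid-generation simply by breaking out of the loop. The pattern in isolation (generic, not tied to the pyllamacpp API):

    # Pull-based token streaming with early exit.
    from typing import Callable, Iterable

    def stream_tokens(tokens: Iterable[str],
                      on_token: Callable[[str], bool]) -> str:
        """Feed tokens to a callback; stop as soon as it returns False."""
        text = ""
        for tok in tokens:
            text += tok
            if not on_token(tok):
                break  # the callback requested cancellation
        return text

    # e.g. stream_tokens(model.generate(prompt), session.new_text_callback)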
@@ -30,13 +30,13 @@ personality_conditionning: |
 welcome_message: "Welcome! I am GPT4All, a free and open assistant. What can I do for you today?"
 
 # This prefix is added at the beginning of any message input by the user
-user_message_prefix: "### Human:"
+user_message_prefix: "### Human:\n"
 
 # A text to put between user and chatbot messages
 link_text: "\n"
 
 # This prefix is added at the beginning of any message output by the AI
-ai_message_prefix: "### Assistant:"
+ai_message_prefix: "### Assistant:\n"
 
 # Here is the list of extensions this personality requires
 dependencies: []
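The trailing "\n" added to both prefixes puts each speaker tag on its own line in the raw prompt. A sketch of how these fields typically compose a prompt (the build_prompt helper is hypothetical):

    personality = {
        "user_message_prefix": "### Human:\n",
        "ai_message_prefix": "### Assistant:\n",
        "link_text": "\n",
    }

    def build_prompt(history: str, user_message: str) -> str:
        p = personality
        return (history
                + p["link_text"] + p["user_message_prefix"] + user_message
                + p["link_text"] + p["ai_message_prefix"])

    # build_prompt("", "Hello!") ->
    # "\n### Human:\nHello!\n### Assistant:\n"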
@@ -196,45 +196,18 @@ class GPT4AllAPI():
         print(text, end="")
         sys.stdout.flush()
 
-        if self.chatbot_bindings.inline:
-            self.bot_says += text
-            if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
-                self.socketio.emit('message', {'data': self.bot_says});
-                if self.cancel_gen:
-                    print("Generation canceled")
-                    return False
-                else:
-                    return True
-            else:
-                self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].lower())
-                print("The model is hallucinating")
-                self.socketio.emit('final', {'data': self.bot_says})
-        else:
-            self.full_text += text
-            if self.is_bot_text_started:
-                self.bot_says += text
-                if not self.personality["user_message_prefix"].lower() in self.bot_says.lower():
-                    self.socketio.emit('message', {'data': self.bot_says});
-                    #self.socketio.emit('message', {'data': text});
-                    if self.cancel_gen:
-                        print("Generation canceled")
-                        self.socketio.emit('final', {'data': self.bot_says})
-                        return False
-                    else:
-                        return True
-                else:
-                    self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].strip())
-                    print("The model is hallucinating")
-                    return False
-
-            #if self.current_message in self.full_text:
-            if len(self.discussion_messages) < len(self.full_text):
-                self.is_bot_text_started = True
-            else:
-                self.socketio.emit('waiter', {'wait': (len(self.discussion_messages)-len(self.full_text))/len(self.discussion_messages)});
-            return True
+        self.bot_says += text
+        if not self.personality["user_message_prefix"].strip().lower() in self.bot_says.lower():
+            self.socketio.emit('message', {'data': self.bot_says});
+            if self.cancel_gen:
+                print("Generation canceled")
+                return False
+            else:
+                return True
+        else:
+            self.bot_says = self.remove_text_from_string(self.bot_says, self.personality["user_message_prefix"].strip())
+            print("The model is hallucinating")
+            self.socketio.emit('final', {'data': self.bot_says})
+            self.cancel_gen=True
+            return False
 
     def generate_message(self):
         self.generating=True
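The simplified callback also shows what "hallucinating" means here: the model has started writing a fake user turn. As soon as the user prefix ("### Human:") shows up in the output, the text is truncated at that marker and generation is cancelled. The check in isolation (remove_text_from_string is assumed to truncate at the first occurrence of the marker, as the diff suggests):

    # Sketch of the anti-prompt check.
    def remove_text_from_string(text: str, marker: str) -> str:
        idx = text.lower().find(marker.lower())
        return text if idx < 0 else text[:idx]

    def on_token(session, text: str) -> bool:
        session.bot_says += text
        marker = session.personality["user_message_prefix"].strip()
        if marker.lower() in session.bot_says.lower():
            # The model began a fake "### Human:" turn: truncate and stop.
            session.bot_says = remove_text_from_string(session.bot_says, marker)
            session.cancel_gen = True
            return False
        return True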
@@ -4,11 +4,13 @@ nomic
 pytest
 pyyaml
 markdown
-pyllamacpp==1.0.7
+pyllamacpp==2.0.0
 gpt4all-j==0.2.1
 --find-links https://download.pytorch.org/whl/cu117
 torch==2.0.0
 torchvision
 torchaudio
 transformers
 accelerate
+gevent
+gevent-websocket
@@ -13,9 +13,6 @@ function send_message(service_name, parameters){
     globals.socket = socket
     globals.is_generating = false
     socket.on('connect', function() {
-        globals.sendbtn.style.display="block";
-        globals.waitAnimation.style.display="none";
-        globals.stopGeneration.style.display = "none";
         entry_counter = 0;
         if(globals.is_generating){
             globals.socket.disconnect()
@@ -29,6 +26,10 @@ function send_message(service_name, parameters){
     socket.on('disconnect', function() {
-        console.log("disconnected")
         entry_counter = 0;
+        console.log("Disconnected")
+        globals.sendbtn.style.display="block";
+        globals.waitAnimation.style.display="none";
+        globals.stopGeneration.style.display = "none";
     });
@@ -40,10 +41,11 @@ function send_message(service_name, parameters){
         }
         globals.bot_msg.setSender(msg.bot);
         globals.bot_msg.setID(msg.response_id);
+        globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please stand by...`;
     });
 
     socket.on('waiter', function(msg) {
-        globals.bot_msg.messageTextElement.innerHTML = `Remaining words ${Math.floor(msg.wait * 100)}%`;
+        globals.bot_msg.messageTextElement.innerHTML = `Generating answer. Please stand by...`;
     });
 
     socket.on('message', function(msg) {
@@ -59,10 +61,8 @@ function send_message(service_name, parameters){
         text = msg.data;
         globals.bot_msg.hiddenElement.innerHTML = text
         globals.bot_msg.messageTextElement.innerHTML = marked.marked(text)
-        socket.disconnect()
-
         globals.sendbtn.style.display="block";
         globals.waitAnimation.style.display="none";
         globals.stopGeneration.style.display = "none";
     });
 }
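For reference, the server-side events the frontend consumes above ('message' for cumulative partial text, 'final' for the finished answer, 'waiter' for progress) can be exercised with a minimal python-socketio client. The URL is an assumption; the payload shapes follow the diff:

    # pip install "python-socketio[client]"
    import socketio

    sio = socketio.Client()

    @sio.on('message')
    def on_message(msg):
        print("partial:", msg['data'][-40:])  # cumulative streamed text

    @sio.on('final')
    def on_final(msg):
        print("final:", msg['data'])
        sio.disconnect()

    sio.connect('http://localhost:9600')  # illustrative host/port
    sio.wait()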