From fb946af84281ec8d967c9992e6c9503f20962e2a Mon Sep 17 00:00:00 2001 From: saloui Date: Thu, 27 Apr 2023 16:44:22 +0200 Subject: [PATCH] Moved to socketio Update README.md --- README.md | 2 +- app.py | 86 ++++++++++++++++++++------------- pyGpt4All/api.py | 3 +- static/js/main.js | 118 ++++++++++++++++------------------------------ 4 files changed, 97 insertions(+), 112 deletions(-) diff --git a/README.md b/README.md index 4ee1a23e..629766ce 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ![GitHub forks](https://img.shields.io/github/forks/nomic-ai/GPT4All-ui) [![Discord](https://img.shields.io/discord/1092918764925882418?color=7289da&label=Discord&logo=discord&logoColor=ffffff)](https://discord.gg/4rR282WJb6) -This is a Flask web application that provides a chat UI for interacting with [llamacpp](https://github.com/ggerganov/llama.cpp) based chatbots such as [GPT4all](https://github.com/nomic-ai/gpt4all), vicuna etc... +This is a Flask web application that provides a chat UI for interacting with [llamacpp](https://github.com/ggerganov/llama.cpp), gpt-j, gpt-q as well as Hugging face based language models uch as [GPT4all](https://github.com/nomic-ai/gpt4all), vicuna etc... Follow us on our [Discord Server](https://discord.gg/4rR282WJb6). diff --git a/app.py b/app.py index 84b2fe06..bac0f8a7 100644 --- a/app.py +++ b/app.py @@ -32,7 +32,7 @@ from flask import ( stream_with_context, send_from_directory ) -from flask_socketio import SocketIO +from flask_socketio import SocketIO, emit from pathlib import Path import gc app = Flask("GPT4All-WebUI", static_url_path="/static", static_folder="static") @@ -150,6 +150,53 @@ class Gpt4AllWebUI(GPT4AllAPI): self.add_endpoint( "/help", "help", self.help, methods=["GET"] ) + + + + # Socket IO stuff + @socketio.on('connect') + def test_connect(): + print('Client connected') + + @socketio.on('disconnect') + def test_disconnect(): + print('Client disconnected') + + @socketio.on('stream-text') + def handle_stream_text(data): + text = data['prompt'] + words = text.split() + for word in words: + emit('stream-word', {'word': word}) + time.sleep(1) # sleep for 1 second to simulate processing time + emit('stream-end') + + + @socketio.on('connected') + def handle_connection(data): + if "data" in data and data["data"]=='Connected!': + return + if self.current_discussion is None: + if self.db.does_last_discussion_have_messages(): + self.current_discussion = self.db.create_discussion() + else: + self.current_discussion = self.db.load_last_discussion() + + message = data["prompt"] + message_id = self.current_discussion.add_message( + "user", message, parent=self.current_message_id + ) + message = data["prompt"] + self.current_message_id = message_id + tpe = threading.Thread(target=self.parse_to_prompt_stream, args=(message, message_id)) + tpe.start() + + # self.parse_to_prompt_stream(message, message_id) + + #self.socketio.emit('message', {'data': 'WebSocket connected!'}) + + #for i in range(10): + # socketio.emit('message', {'data': 'Message ' + str(i)}) def list_backends(self): backends_dir = Path('./backends') # replace with the actual path to the models folder @@ -242,7 +289,6 @@ class Gpt4AllWebUI(GPT4AllAPI): return jsonify({"discussion_text":self.get_discussion_to()}) - @stream_with_context def parse_to_prompt_stream(self, message, message_id): bot_says = "" @@ -252,8 +298,7 @@ class Gpt4AllWebUI(GPT4AllAPI): response_id = self.current_discussion.add_message( self.personality["name"], "", parent = message_id ) # first the content is empty, but we'll fill it at the end - yield ( - json.dumps( + socketio.emit('infos', { "type": "input_message_infos", "bot": self.personality["name"], @@ -262,8 +307,8 @@ class Gpt4AllWebUI(GPT4AllAPI): "id": message_id, "response_id": response_id, } - ) - ) + ) + # prepare query and reception self.discussion_messages = self.prepare_query(message_id) @@ -271,41 +316,18 @@ class Gpt4AllWebUI(GPT4AllAPI): self.generating = True # app.config['executor'] = ThreadPoolExecutor(max_workers=1) # app.config['executor'].submit(self.generate_message) - tpe = threading.Thread(target=self.generate_message) - tpe.start() - while self.generating: - try: - while not self.text_queue.empty(): - value = self.text_queue.get(False) - if self.cancel_gen: - self.generating = False - break - yield value - time.sleep(0) - except Exception as ex: - print(f"Exception {ex}") - time.sleep(0.1) - if self.cancel_gen: - self.generating = False - tpe = None - gc.collect() + self.generate_message() print("## Done ##") self.current_discussion.update_message(response_id, self.bot_says) self.full_message_list.append(self.bot_says) bot_says = markdown.markdown(self.bot_says) - yield "FINAL:"+bot_says + socketio.emit('final', {'data': bot_says}) self.cancel_gen = False return bot_says - # Socket IO stuff - @socketio.on('connected') - def handle_connection(self, data): - self.socketio.emit('message', {'data': 'WebSocket connected!'}) - - for i in range(10): - socketio.emit('message', {'data': 'Message ' + str(i)}) + def generate(self): diff --git a/pyGpt4All/api.py b/pyGpt4All/api.py index 5caaf66d..c9294e32 100644 --- a/pyGpt4All/api.py +++ b/pyGpt4All/api.py @@ -202,7 +202,7 @@ class GPT4AllAPI(): if self.chatbot_bindings.inline: self.bot_says += text if not self.personality["user_message_prefix"].lower() in self.bot_says.lower(): - self.text_queue.put(text) + self.socketio.emit('message', {'data': text}); if self.cancel_gen: print("Generation canceled") return False @@ -235,7 +235,6 @@ class GPT4AllAPI(): def generate_message(self): self.generating=True - self.text_queue=Queue() gc.collect() total_n_predict = self.config['n_predict'] print(f"Generating {total_n_predict} outputs... ") diff --git a/static/js/main.js b/static/js/main.js index 12dfb1ec..46ed0055 100644 --- a/static/js/main.js +++ b/static/js/main.js @@ -38,85 +38,49 @@ function update_main(){ // scroll to bottom of chat window chatWindow.scrollTop = chatWindow.scrollHeight; - fetch('/generate', { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ message }) - }).then(function(response) { - const stream = new ReadableStream({ - start(controller) { - const reader = response.body.getReader(); - function push() { - reader.read().then(function(result) { - if (result.done) { - sendbtn.style.display="block"; - waitAnimation.style.display="none"; - stopGeneration.style.display = "none"; - console.log(result) - controller.close(); - return; - } - controller.enqueue(result.value); - push(); - }) - } - push(); - } - }); - const textDecoder = new TextDecoder(); - const readableStreamDefaultReader = stream.getReader(); - let entry_counter = 0 - function readStream() { - readableStreamDefaultReader.read().then(function(result) { - if (result.done) { - return; - } + var socket = io.connect('http://' + document.domain + ':' + location.port); + entry_counter = 0; - text = textDecoder.decode(result.value); - - // The server will first send a json containing information about the message just sent - if(entry_counter==0) - { - // We parse it and - infos = JSON.parse(text); - - user_msg.setSender(infos.user); - user_msg.setMessage(infos.message); - user_msg.setID(infos.id); - bot_msg.setSender(infos.bot); - bot_msg.setID(infos.response_id); - - bot_msg.messageTextElement; - bot_msg.hiddenElement; - entry_counter ++; - } - else{ - entry_counter ++; - prefix = "FINAL:"; - if(text.startsWith(prefix)){ - text = text.substring(prefix.length); - bot_msg.hiddenElement.innerHTML = text - bot_msg.messageTextElement.innerHTML = text - } - else{ - // For the other enrtries, these are just the text of the chatbot - txt = bot_msg.hiddenElement.innerHTML; - txt += text - bot_msg.hiddenElement.innerHTML = txt; - bot_msg.messageTextElement.innerHTML = txt; - // scroll to bottom of chat window - chatWindow.scrollTop = chatWindow.scrollHeight; - - } - } - - readStream(); - }); - } - readStream(); + socket.on('connect', function() { + socket.emit('connected', {prompt: message}); + entry_counter = 0; }); + socket.on('disconnect', function() { + + entry_counter = 0; + }); + socket.on() + + socket.on('infos', function(msg) { + user_msg.setSender(msg.user); + user_msg.setMessage(msg.message); + user_msg.setID(msg.id); + bot_msg.setSender(msg.bot); + bot_msg.setID(msg.response_id); + }); + + socket.on('message', function(msg) { + text = msg.data; + console.log(text) + // For the other enrtries, these are just the text of the chatbot + txt = bot_msg.hiddenElement.innerHTML; + txt += text + bot_msg.hiddenElement.innerHTML = txt; + bot_msg.messageTextElement.innerHTML = txt; + // scroll to bottom of chat window + chatWindow.scrollTop = chatWindow.scrollHeight; + }); + + socket.on('final',function(msg){ + text = msg.data; + bot_msg.hiddenElement.innerHTML = text + bot_msg.messageTextElement.innerHTML = text + sendbtn.style.display="block"; + waitAnimation.style.display="none"; + stopGeneration.style.display = "none"; + socket.disconnect(); + }); + //socket.emit('stream-text', {text: text}); } chatForm.addEventListener('submit', event => { event.preventDefault();