From 126b85868e6192bbb0b5f80953a4087c02c866be Mon Sep 17 00:00:00 2001
From: saloui
Date: Thu, 13 Jul 2023 16:49:54 +0200
Subject: [PATCH] Preparing for separated connections

---
 api/__init__.py | 104 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 76 insertions(+), 28 deletions(-)

diff --git a/api/__init__.py b/api/__init__.py
index 4bbc5f4d..3e34219c 100644
--- a/api/__init__.py
+++ b/api/__init__.py
@@ -30,6 +30,19 @@ import sys
 from lollms.console import MainMenu
 import urllib
 import gc
+import ctypes
+from functools import partial
+
+def terminate_thread(thread):
+    if thread is None or not thread.is_alive():
+        return
+
+    thread_id = thread.ident
+    exc = ctypes.py_object(SystemExit)
+    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), exc)
+    if res > 1:
+        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), None)
+        raise SystemError("Failed to terminate the thread.")
 
 __author__ = "parisneo"
 __github__ = "https://github.com/ParisNeo/lollms-webui"
@@ -96,6 +109,7 @@ class LoLLMsAPPI(LollmsApplication):
         super().__init__("Lollms_webui",config, lollms_paths)
 
         self.is_ready = True
+        self.socketio = socketio
         self.config_file_path = config_file_path
         self.cancel_gen = False
 
@@ -122,17 +136,31 @@ class LoLLMsAPPI(LollmsApplication):
 
         # This is used to keep track of messages
         self.full_message_list = []
-        self.current_room_id = None
         self.download_infos={}
+
+        self.connections = {}
+
         # =========================================================================================
         # Socket IO stuff
         # =========================================================================================
         @socketio.on('connect')
         def connect():
+            # Create a new connection information entry for this client
+            self.connections[request.sid] = {
+                "current_discussion":None,
+                "generated_text":"",
+                "cancel_generation": False,
+                "generation_thread": None
+            }
             ASCIIColors.success(f'Client {request.sid} connected')
 
         @socketio.on('disconnect')
         def disconnect():
+            try:
+                del self.connections[request.sid]
+            except Exception:
+                pass
+
             ASCIIColors.error(f'Client {request.sid} disconnected')
 
 
@@ -146,13 +174,12 @@ class LoLLMsAPPI(LollmsApplication):
             self.socketio.emit('canceled', {
                                     'status': True
                                     },
-                                    room=self.current_room_id
+                                    room=request.sid
             )
 
         @socketio.on('install_model')
         def install_model(data):
-            room_id = request.sid
-
+            room_id = request.sid
             def install_model_():
                 print("Install model triggered")
 
@@ -347,13 +374,21 @@ class LoLLMsAPPI(LollmsApplication):
 
         @socketio.on('cancel_generation')
         def cancel_generation():
+            client_id = request.sid
             self.cancel_gen = True
+            # Kill the generation thread for this client
+            ASCIIColors.error(f'Client {request.sid} requested cancelling generation')
+            terminate_thread(self.connections[client_id]['generation_thread'])
             ASCIIColors.error(f'Client {request.sid} canceled generation')
+            self.cancel_gen = False
 
         @socketio.on('generate_msg')
         def generate_msg(data):
-            self.current_room_id = request.sid
+            client_id = request.sid
+            self.connections[client_id]["generated_text"]=""
+            self.connections[client_id]["cancel_generation"]=False
+
             if self.is_ready:
                 if self.current_discussion is None:
                     if self.db.does_last_discussion_have_messages():
                         self.current_discussion = self.db.load_last_discussion()
@@ -371,15 +406,16 @@ class LoLLMsAPPI(LollmsApplication):
 
                 self.current_user_message_id = message_id
                 ASCIIColors.green("Starting message generation by"+self.personality.name)
-
-                task = self.socketio.start_background_task(self.start_message_generation, message, message_id)
+                self.connections[client_id]['generation_thread'] = threading.Thread(target=self.start_message_generation, args=(message, message_id, client_id))
+                self.connections[client_id]['generation_thread'].start()
+                self.socketio.sleep(0.01)
                 ASCIIColors.info("Started generation task")
-                #tpe = threading.Thread(target=self.start_message_generation, args=(message, message_id))
+                #tpe = threading.Thread(target=self.start_message_generation, args=(message, message_id, client_id))
                 #tpe.start()
             else:
-                self.socketio.emit("buzzy", {"message":"I am buzzy. Come back later."}, room=self.current_room_id)
+                self.socketio.emit("buzzy", {"message":"I am busy. Come back later."}, room=client_id)
                 self.socketio.sleep(0.01)
-                ASCIIColors.warning(f"OOps request {self.current_room_id} refused!! Server buzy")
+                ASCIIColors.warning(f"Oops, request {client_id} refused! Server busy")
                 self.socketio.emit('infos',
                         {
                             "status":'model_not_ready',
@@ -397,17 +433,19 @@ class LoLLMsAPPI(LollmsApplication):
                             'personality': self.current_discussion.current_message_personality,
                             'created_at': self.current_discussion.current_message_created_at,
                             'finished_generating_at': self.current_discussion.current_message_finished_generating_at,
-                        }, room=self.current_room_id
+                        }, room=client_id
                 )
                 self.socketio.sleep(0.01)
 
         @socketio.on('generate_msg_from')
         def handle_connection(data):
+            client_id = request.sid
             message_id = int(data['id'])
             message = data["prompt"]
             self.current_user_message_id = message_id
-            tpe = threading.Thread(target=self.start_message_generation, args=(message, message_id))
-            tpe.start()
+            self.connections[client_id]['generation_thread'] = threading.Thread(target=self.start_message_generation, args=(message, message_id, client_id))
+            self.connections[client_id]['generation_thread'].start()
+
         # generation status
         self.generating=False
         ASCIIColors.blue(f"Your personal data is stored here :",end="")
@@ -416,11 +454,13 @@ class LoLLMsAPPI(LollmsApplication):
 
         @socketio.on('continue_generate_msg_from')
         def handle_connection(data):
+            client_id = request.sid
             message_id = int(data['id'])
             message = data["prompt"]
             self.current_user_message_id = message_id
-            tpe = threading.Thread(target=self.start_message_generation, args=(message, message_id))
-            tpe.start()
+            self.connections[client_id]['generation_thread'] = threading.Thread(target=self.start_message_generation, args=(message, message_id, client_id))
+            self.connections[client_id]['generation_thread'].start()
+
         # generation status
         self.generating=False
         ASCIIColors.blue(f"Your personal data is stored here :",end="")
@@ -653,7 +693,7 @@ class LoLLMsAPPI(LollmsApplication):
 
         return string
 
-    def process_chunk(self, chunk, message_type:MSG_TYPE):
+    def process_chunk(self, chunk, message_type:MSG_TYPE, client_id):
         """
         0 : a regular message
         1 : a notification message
@@ -688,7 +728,7 @@ class LoLLMsAPPI(LollmsApplication):
                                 'ai_message_id':self.current_ai_message_id,
                                 'discussion_id':self.current_discussion.discussion_id,
                                 'message_type': MSG_TYPE.MSG_TYPE_FULL.value
-                            }, room=self.current_room_id
+                            }, room=client_id
                 )
                 self.socketio.sleep(0.01)
                 self.current_discussion.update_message(self.current_ai_message_id, self.current_generated_text)
@@ -715,7 +755,7 @@ class LoLLMsAPPI(LollmsApplication):
                                 'ai_message_id':self.current_ai_message_id,
                                 'discussion_id':self.current_discussion.discussion_id,
                                 'message_type': message_type.value
-                            }, room=self.current_room_id
+                            }, room=client_id
                 )
                 self.socketio.sleep(0.01)
                 return True
@@ -727,7 +767,7 @@ class LoLLMsAPPI(LollmsApplication):
                             'ai_message_id':self.current_ai_message_id,
                             'discussion_id':self.current_discussion.discussion_id,
                             'message_type': message_type.value
-                        }, room=self.current_room_id
+                        }, room=client_id
             )
             self.socketio.sleep(0.01)
 
@@ -738,8 +778,9 @@ class LoLLMsAPPI(LollmsApplication):
         if self.personality.processor is not None:
             ASCIIColors.success("Running workflow")
             try:
-                output = self.personality.processor.run_workflow( prompt, full_prompt, self.process_chunk)
-                self.process_chunk(output, MSG_TYPE.MSG_TYPE_FULL)
+                output = self.personality.processor.run_workflow( prompt, full_prompt, callback)
+                if callback:
+                    callback(output, MSG_TYPE.MSG_TYPE_FULL)
             except Exception as ex:
                 # Catch the exception and get the traceback as a list of strings
                 traceback_lines = traceback.format_exception(type(ex), ex, ex.__traceback__)
@@ -747,7 +788,8 @@ class LoLLMsAPPI(LollmsApplication):
                 traceback_text = ''.join(traceback_lines)
                 ASCIIColors.error(f"Workflow run failed.\nError:{ex}")
                 ASCIIColors.error(traceback_text)
-                self.process_chunk(f"Workflow run failed\nError:{ex}", MSG_TYPE.MSG_TYPE_EXCEPTION)
+                if callback:
+                    callback(f"Workflow run failed\nError:{ex}", MSG_TYPE.MSG_TYPE_EXCEPTION)
             print("Finished executing the workflow")
             return
 
@@ -793,8 +835,8 @@ class LoLLMsAPPI(LollmsApplication):
             output = ""
         return output
 
-    def start_message_generation(self, message, message_id, is_continue=False):
-        ASCIIColors.info(f"Text generation requested by client: {self.current_room_id}")
+    def start_message_generation(self, message, message_id, client_id, is_continue=False):
+        ASCIIColors.info(f"Text generation requested by client: {client_id}")
         # send the message to the bot
         print(f"Received message : {message}")
         if self.current_discussion:
@@ -826,7 +868,7 @@ class LoLLMsAPPI(LollmsApplication):
                     'personality': self.current_discussion.current_message_personality,
                     'created_at': self.current_discussion.current_message_created_at,
                     'finished_generating_at': self.current_discussion.current_message_finished_generating_at,
-                }, room=self.current_room_id
+                }, room=client_id
             )
             self.socketio.sleep(0.01)
 
@@ -834,7 +876,13 @@ class LoLLMsAPPI(LollmsApplication):
             self.discussion_messages, self.current_message, tokens = self.prepare_query(message_id, is_continue)
             self.prepare_reception()
             self.generating = True
-            self.generate(self.discussion_messages, self.current_message, n_predict = self.config.ctx_size-len(tokens)-1, callback=self.process_chunk)
+            self.generate(
+                self.discussion_messages,
+                self.current_message,
+                n_predict=self.config.ctx_size-len(tokens)-1,
+                callback=partial(self.process_chunk, client_id=client_id)
+
+            )
             print()
             print("## Done Generation ##")
             print()
@@ -863,7 +911,7 @@ class LoLLMsAPPI(LollmsApplication):
                     'created_at': self.current_discussion.current_message_created_at,
                     'finished_generating_at': self.current_discussion.current_message_finished_generating_at,
 
-                }, room=self.current_room_id
+                }, room=client_id
             )
             self.socketio.sleep(0.01)
 
@@ -880,7 +928,7 @@ class LoLLMsAPPI(LollmsApplication):
                     'ai_message_id':self.current_ai_message_id,
                     'discussion_id':0,
                     'message_type': MSG_TYPE.MSG_TYPE_EXCEPTION.value
-                }, room=self.current_room_id
+                }, room=client_id
            )
             print()
         return ""
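
Note on the thread-termination technique introduced above: PyThreadState_SetAsyncExc does not kill a thread outright; it schedules an exception (here SystemExit) that is raised the next time the target thread executes Python bytecode. A thread blocked inside a long C call (a model inference step, a socket read) will therefore not stop until control returns to the interpreter. A return value of 0 means the thread id was not found, 1 means the exception was scheduled, and more than 1 means several thread states were modified, which must be undone before raising. The following is a standalone sketch of the same recipe; the worker function is illustrative and not part of lollms-webui:

    import ctypes
    import threading
    import time

    def terminate_thread(thread):
        # Schedule SystemExit inside the target thread (same recipe as the patch).
        if thread is None or not thread.is_alive():
            return
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread.ident), ctypes.py_object(SystemExit))
        if res > 1:
            # More than one thread state was affected: undo and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread.ident), None)
            raise SystemError("Failed to terminate the thread.")

    def worker():
        # Stands in for a long generation loop; sleep() returns control to the
        # interpreter regularly, so the scheduled exception can fire.
        while True:
            time.sleep(0.1)

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    terminate_thread(t)
    t.join(timeout=2)
    print("worker alive:", t.is_alive())  # expected: False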
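
Note on the per-client callback pattern: process_chunk now takes a client_id, and the generate() call binds it with functools.partial so the model loop stays client-agnostic while every chunk is routed to the right per-client state and socket room. A minimal sketch of the idea under simplified assumptions; FakeServer and its methods are illustrative stand-ins, not the real LoLLMsAPPI API:

    from functools import partial

    class FakeServer:
        def __init__(self):
            # Mirrors self.connections in the patch: one state dict per client.
            self.connections = {}

        def process_chunk(self, chunk, message_type, client_id):
            # Accumulate per-client text instead of touching shared state.
            self.connections[client_id]["generated_text"] += chunk
            print(f"[{client_id}] {message_type}: {chunk!r}")

        def generate(self, prompt, callback):
            # Stand-in for the model: emits chunks to whatever callback it got.
            for chunk in (prompt[:2], prompt[2:]):
                callback(chunk, "MSG_TYPE_CHUNK")

    server = FakeServer()
    for client_id in ("sid-1", "sid-2"):
        server.connections[client_id] = {"generated_text": ""}
        # partial() freezes client_id, so generate() never needs to know it.
        server.generate("hello", partial(server.process_chunk, client_id=client_id))

    print(server.connections["sid-1"]["generated_text"])  # -> hello

A lambda would work for the binding as well; partial is used here, as in the patch, because it keeps the bound argument explicit and inspectable.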