From e460dc99d3526b1c9278f845d6665cf852ab6cde Mon Sep 17 00:00:00 2001
From: saloui
Date: Thu, 14 Sep 2023 10:38:33 +0200
Subject: [PATCH] enhanced quality of prompting

---
 api/__init__.py               | 35 +++++++++++++++++++++--------------
 api/db.py                     |  3 ++-
 app.py                        | 11 ++++++++++-
 static/js/websocket.js        | 12 ++++++++++++
 web/src/services/websocket.js |  6 +++++-
 5 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/api/__init__.py b/api/__init__.py
index 872f21b3..5611b1a7 100644
--- a/api/__init__.py
+++ b/api/__init__.py
@@ -10,9 +10,7 @@
 from flask import request
 from datetime import datetime
 from api.db import DiscussionsDB, Discussion
-from api.helpers import compare_lists
 from pathlib import Path
-import importlib
 from lollms.config import InstallOption
 from lollms.types import MSG_TYPE, SENDER_TYPES
 from lollms.personality import AIPersonality, PersonalityBuilder
@@ -21,20 +19,17 @@ from lollms.paths import LollmsPaths
 from lollms.helpers import ASCIIColors, trace_exception
 from lollms.app import LollmsApplication
 from lollms.utilities import File64BitsManager, PromptReshaper
-import multiprocessing as mp
 import threading
-import time
-import requests
 from tqdm import tqdm
 import traceback
 import sys
-from lollms.terminal import MainMenu
-import urllib
 import gc
 import ctypes
 from functools import partial
 import json
 import shutil
+import re
+import string
 
 def terminate_thread(thread):
     if thread:
@@ -162,7 +157,7 @@ class LoLLMsAPPI(LollmsApplication):
         def connect():
             #Create a new connection information
             self.connections[request.sid] = {
-                "current_discussion":None,
+                "current_discussion":self.db.load_last_discussion(),
                 "generated_text":"",
                 "cancel_generation": False,
                 "generation_thread": None,
@@ -1001,6 +996,20 @@ class LoLLMsAPPI(LollmsApplication):
                 self.nb_received_tokens = 0
 
 
+    def clean_string(self, input_string):
+        # Remove extra spaces by replacing multiple spaces with a single space
+        #cleaned_string = re.sub(r'\s+', ' ', input_string)
+
+        # Remove extra line breaks by replacing multiple consecutive line breaks with a single line break
+        cleaned_string = re.sub(r'\n\s*\n', '\n', input_string)
+        # Create a string containing all punctuation characters
+        punctuation_chars = string.punctuation
+        # Define a regular expression pattern to match and remove non-alphanumeric characters
+        pattern = f'[^a-zA-Z0-9\s{re.escape(punctuation_chars)}]' # This pattern matches any character that is not a letter, digit, space, or punctuation
+
+        # Use re.sub to replace the matched characters with an empty string
+        cleaned_string = re.sub(pattern, '', cleaned_string)
+        return cleaned_string
     def prepare_query(self, client_id, message_id=-1, is_continue=False):
         messages = self.connections[client_id]["current_discussion"].get_messages()
         full_message_list = []
@@ -1033,10 +1042,10 @@ class LoLLMsAPPI(LollmsApplication):
             conditionning = self.personality.personality_conditioning
 
         if self.config["override_personality_model_parameters"]:
-            conditionning = conditionning+ "!@>user description:\nName:"+self.config["user_name"]+"\n"+self.config["user_description"]+"\n"
+            conditionning = conditionning+ "\n!@>user description:\nName:"+self.config["user_name"]+"\n"+self.config["user_description"]+"\n"
 
         if len(self.personality.files)>0 and self.personality.vectorizer:
-            pr = PromptReshaper("!@>document chunks:\n{{doc}}\n{{conditionning}}\n{{content}}")
+            pr = PromptReshaper("{{conditionning}}\n!@>document chunks:\n{{doc}}\n{{content}}")
             emb = self.personality.vectorizer.embed_query(message.content)
             docs, sorted_similarities = self.personality.vectorizer.recover_text(emb, top_k=self.config.data_vectorization_nb_chunks)
             str_docs = ""
@@ -1053,15 +1062,13 @@ class LoLLMsAPPI(LollmsApplication):
                                   "conditionning":conditionning,
                                   "content":discussion_messages
                                   }, self.model.tokenize, self.model.detokenize, self.config.ctx_size, place_holders_to_sacrifice=["content"])
-
+        # remove extra returns
+        discussion_messages = self.clean_string(discussion_messages)
         tokens = self.model.tokenize(discussion_messages)
         if self.config["debug"]:
             ASCIIColors.yellow(discussion_messages)
             ASCIIColors.info(f"prompt size:{len(tokens)} tokens")
-
-
-
         return discussion_messages, message.content, tokens
 
 
     def get_discussion_to(self, client_id, message_id=-1):
diff --git a/api/db.py b/api/db.py
index f2db3a5b..492aabc7 100644
--- a/api/db.py
+++ b/api/db.py
@@ -455,7 +455,8 @@ class Discussion:
     def __init__(self, discussion_id, discussions_db:DiscussionsDB):
         self.discussion_id = discussion_id
         self.discussions_db = discussions_db
-        self.messages = []
+        self.messages = self.get_messages()
+        self.current_message = self.messages[-1]
 
     def load_message(self, id):
         """Gets a list of messages information
diff --git a/app.py b/app.py
index d4c726f5..cdaec250 100644
--- a/app.py
+++ b/app.py
@@ -209,6 +209,10 @@ class LoLLMsWebUI(LoLLMsAPPI):
 
         # Endpoints
         # =========================================================================================
+        self.add_endpoint(
+            "/get_current_personality_files_list", "get_current_personality_files_list", self.get_current_personality_files_list, methods=["GET"]
+        )
+
         self.add_endpoint("/start_training", "start_training", self.start_training, methods=["POST"])
 
         self.add_endpoint("/get_lollms_version", "get_lollms_version", self.get_lollms_version, methods=["GET"])
@@ -1372,7 +1376,12 @@ class LoLLMsWebUI(LoLLMsAPPI):
         ASCIIColors.info("")
         ASCIIColors.info("")
         run_update_script(self.args)
-
+
+    def get_current_personality_files_list(self):
+        if self.personality is None:
+            return jsonify({"state":False, "error":"No personality selected"})
+        return jsonify({"state":True, "files":[Path(f).name for f in self.personality.files]})
+
     def start_training(self):
         if self.config.enable_gpu:
             if not self.lollms_paths.gptqlora_path.exists():
diff --git a/static/js/websocket.js b/static/js/websocket.js
index ae6a7f52..4ebce7a0 100644
--- a/static/js/websocket.js
+++ b/static/js/websocket.js
@@ -19,6 +19,18 @@ var socket = io.connect(location.protocol + '//' + document.domain + ':' + locat
 socket.on('connect', function() {
     console.log("Disconnected")
 });
+// Handle reconnection attempt failure
+socket.on('reconnect_failed', () => {
+    console.log('All reconnection attempts failed');
+    // You can perform any custom actions or error handling here
+  });
+
+// Handle reconnection attempt
+socket.on('reconnect_attempt', () => {
+    reconnectionAttempt++;
+    console.log(`Reconnection attempt ${reconnectionAttempt}...`);
+    // You can perform custom actions or error handling for each attempt here
+  });
 socket.on('disconnect', function() {
     console.log("Disconnected")
 });
diff --git a/web/src/services/websocket.js b/web/src/services/websocket.js
index c1490f48..4064624e 100644
--- a/web/src/services/websocket.js
+++ b/web/src/services/websocket.js
@@ -8,7 +8,11 @@ import io from 'socket.io-client';
 
 // fixes issues when people not hosting this site on local network
 const URL = process.env.NODE_ENV === "production" ? undefined : (import.meta.env.VITE_LOLLMS_API);
-const socket = new io(URL);
+const socket = new io(URL,{
+    reconnection: true, // Enable reconnection
+    reconnectionAttempts: 3, // Maximum reconnection attempts
+    reconnectionDelay: 1000, // Delay between reconnection attempts (in milliseconds)
+  });
 
 
 // const app = createApp(/* your root component */);
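
Note: the behavior of the new clean_string helper added to api/__init__.py can be checked in isolation. The following is a minimal sketch that reproduces the method as a standalone function; the sample input and the repr check are illustrative only and are not part of the patch.

    import re
    import string

    def clean_string(input_string):
        # Collapse runs of blank lines into a single line break
        cleaned_string = re.sub(r'\n\s*\n', '\n', input_string)
        # Keep only letters, digits, whitespace and standard punctuation;
        # anything else (e.g. emojis or stray glyphs) is stripped
        punctuation_chars = string.punctuation
        pattern = f'[^a-zA-Z0-9\\s{re.escape(punctuation_chars)}]'
        return re.sub(pattern, '', cleaned_string)

    print(repr(clean_string("Hello!\n\n\nHow are you? \u2728")))
    # -> 'Hello!\nHow are you? '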
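
Note: the new /get_current_personality_files_list endpoint can be smoke-tested against a running instance. The sketch below assumes a hypothetical local address (http://localhost:9600); adjust the host and port to your own configuration. The response shape matches the jsonify payloads in the patch.

    import requests

    # Hypothetical local instance; adjust host/port to match your setup
    response = requests.get("http://localhost:9600/get_current_personality_files_list")
    data = response.json()

    if data["state"]:
        print("Files attached to the current personality:", data["files"])
    else:
        print("Error:", data["error"])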