enhanced quality of prompting

This commit is contained in:
saloui 2023-09-14 10:38:33 +02:00
parent 0d1dfaaf1d
commit e460dc99d3
5 changed files with 50 additions and 17 deletions

View File

@ -10,9 +10,7 @@
from flask import request
from datetime import datetime
from api.db import DiscussionsDB, Discussion
from api.helpers import compare_lists
from pathlib import Path
import importlib
from lollms.config import InstallOption
from lollms.types import MSG_TYPE, SENDER_TYPES
from lollms.personality import AIPersonality, PersonalityBuilder
@ -21,20 +19,17 @@ from lollms.paths import LollmsPaths
from lollms.helpers import ASCIIColors, trace_exception
from lollms.app import LollmsApplication
from lollms.utilities import File64BitsManager, PromptReshaper
import multiprocessing as mp
import threading
import time
import requests
from tqdm import tqdm
import traceback
import sys
from lollms.terminal import MainMenu
import urllib
import gc
import ctypes
from functools import partial
import json
import shutil
import re
import string
def terminate_thread(thread):
if thread:
@ -162,7 +157,7 @@ class LoLLMsAPPI(LollmsApplication):
def connect():
#Create a new connection information
self.connections[request.sid] = {
"current_discussion":None,
"current_discussion":self.db.load_last_discussion(),
"generated_text":"",
"cancel_generation": False,
"generation_thread": None,
@ -1001,6 +996,20 @@ class LoLLMsAPPI(LollmsApplication):
self.nb_received_tokens = 0
def clean_string(self, input_string):
# Remove extra spaces by replacing multiple spaces with a single space
#cleaned_string = re.sub(r'\s+', ' ', input_string)
# Remove extra line breaks by replacing multiple consecutive line breaks with a single line break
cleaned_string = re.sub(r'\n\s*\n', '\n', input_string)
# Create a string containing all punctuation characters
punctuation_chars = string.punctuation
# Define a regular expression pattern to match and remove non-alphanumeric characters
pattern = f'[^a-zA-Z0-9\s{re.escape(punctuation_chars)}]' # This pattern matches any character that is not a letter, digit, space, or punctuation
# Use re.sub to replace the matched characters with an empty string
cleaned_string = re.sub(pattern, '', cleaned_string)
return cleaned_string
def prepare_query(self, client_id, message_id=-1, is_continue=False):
messages = self.connections[client_id]["current_discussion"].get_messages()
full_message_list = []
@ -1033,10 +1042,10 @@ class LoLLMsAPPI(LollmsApplication):
conditionning = self.personality.personality_conditioning
if self.config["override_personality_model_parameters"]:
conditionning = conditionning+ "!@>user description:\nName:"+self.config["user_name"]+"\n"+self.config["user_description"]+"\n"
conditionning = conditionning+ "\n!@>user description:\nName:"+self.config["user_name"]+"\n"+self.config["user_description"]+"\n"
if len(self.personality.files)>0 and self.personality.vectorizer:
pr = PromptReshaper("!@>document chunks:\n{{doc}}\n{{conditionning}}\n{{content}}")
pr = PromptReshaper("{{conditionning}}\n!@>document chunks:\n{{doc}}\n{{content}}")
emb = self.personality.vectorizer.embed_query(message.content)
docs, sorted_similarities = self.personality.vectorizer.recover_text(emb, top_k=self.config.data_vectorization_nb_chunks)
str_docs = ""
@ -1053,15 +1062,13 @@ class LoLLMsAPPI(LollmsApplication):
"conditionning":conditionning,
"content":discussion_messages
}, self.model.tokenize, self.model.detokenize, self.config.ctx_size, place_holders_to_sacrifice=["content"])
# remove extra returns
discussion_messages = self.clean_string(discussion_messages)
tokens = self.model.tokenize(discussion_messages)
if self.config["debug"]:
ASCIIColors.yellow(discussion_messages)
ASCIIColors.info(f"prompt size:{len(tokens)} tokens")
return discussion_messages, message.content, tokens
def get_discussion_to(self, client_id, message_id=-1):

View File

@ -455,7 +455,8 @@ class Discussion:
def __init__(self, discussion_id, discussions_db:DiscussionsDB):
self.discussion_id = discussion_id
self.discussions_db = discussions_db
self.messages = []
self.messages = self.get_messages()
self.current_message = self.messages[-1]
def load_message(self, id):
"""Gets a list of messages information

11
app.py
View File

@ -209,6 +209,10 @@ class LoLLMsWebUI(LoLLMsAPPI):
# Endpoints
# =========================================================================================
self.add_endpoint(
"/get_current_personality_files_list", "get_current_personality_files_list", self.get_current_personality_files_list, methods=["GET"]
)
self.add_endpoint("/start_training", "start_training", self.start_training, methods=["POST"])
self.add_endpoint("/get_lollms_version", "get_lollms_version", self.get_lollms_version, methods=["GET"])
@ -1372,7 +1376,12 @@ class LoLLMsWebUI(LoLLMsAPPI):
ASCIIColors.info("")
ASCIIColors.info("")
run_update_script(self.args)
def get_current_personality_files_list(self):
if self.personality is None:
return jsonify({"state":False, "error":"No personality selected"})
return jsonify({"state":True, "files":[Path(f).name for f in self.personality.files]})
def start_training(self):
if self.config.enable_gpu:
if not self.lollms_paths.gptqlora_path.exists():

View File

@ -19,6 +19,18 @@ var socket = io.connect(location.protocol + '//' + document.domain + ':' + locat
socket.on('connect', function() {
console.log("Disconnected")
});
// Handle reconnection attempt failure
socket.on('reconnect_failed', () => {
console.log('All reconnection attempts failed');
// You can perform any custom actions or error handling here
});
// Handle reconnection attempt
socket.on('reconnect_attempt', () => {
reconnectionAttempt++;
console.log(`Reconnection attempt ${reconnectionAttempt}...`);
// You can perform custom actions or error handling for each attempt here
});
socket.on('disconnect', function() {
console.log("Disconnected")
});

View File

@ -8,7 +8,11 @@ import io from 'socket.io-client';
// fixes issues when people not hosting this site on local network
const URL = process.env.NODE_ENV === "production" ? undefined : (import.meta.env.VITE_LOLLMS_API);
const socket = new io(URL);
const socket = new io(URL,{
reconnection: true, // Enable reconnection
reconnectionAttempts: 3, // Maximum reconnection attempts
reconnectionDelay: 1000, // Delay between reconnection attempts (in milliseconds)
});
// const app = createApp(/* your root component */);