Mirror of https://github.com/ParisNeo/lollms.git
Synced 2024-12-18 20:27:58 +00:00

Commit 58a0c986c3 ("Enhanced"), parent d21ba70dff.
@@ -14,7 +14,7 @@ class MyConversation(Conversation):
         self.menu.main_menu()
         full_discussion += self.personality.user_message_prefix+prompt+self.personality.link_text
         full_discussion += self.personality.ai_message_prefix
-        def callback(text, type=None):
+        def callback(text, type=None, metadata:dict={}):
             print(text, end="")
             sys.stdout = sys.__stdout__
             sys.stdout.flush()
@@ -7,7 +7,7 @@ class MyConversation(Conversation):
     def start_conversation(self):
         prompt = "Once upon a time"
-        def callback(text, type=None):
+        def callback(text, type=None, metadata:dict={}):
             print(text, end="")
             sys.stdout = sys.__stdout__
             sys.stdout.flush()
 
@@ -7,6 +7,7 @@ from lollms.paths import LollmsPaths
 from lollms.app import LollmsApplication
 from lollms.terminal import MainMenu
 
+from typing import Callable
 from pathlib import Path
 import argparse
 import yaml
@@ -133,7 +134,7 @@ Participating personalities:
         full_discussion = ""
         return full_discussion
 
-    def safe_generate(self, full_discussion:str, n_predict=None, callback=None):
+    def safe_generate(self, full_discussion:str, n_predict=None, callback: Callable[[str, int, dict], bool]=None):
         """safe_generate
 
         Args:
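The widened annotation makes the callback contract explicit. A minimal hedged sketch of a caller, assuming the surrounding console app instance and the usual chunk-streaming convention (none of this is part of the commit):

```python
# Sketch: a callback matching Callable[[str, int, dict], bool].
def my_callback(chunk: str, msg_type: int = 0, metadata: dict = {}) -> bool:
    print(chunk, end="", flush=True)
    return True  # returning False is assumed to request an early stop

# app.safe_generate(full_discussion, n_predict=128, callback=my_callback)
```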
@@ -236,7 +237,7 @@ Participating personalities:
                         self.personality.ai_message_prefix
                     )
 
-                def callback(text, type:MSG_TYPE=None):
+                def callback(text, type:MSG_TYPE=None, metadata:dict={}):
                     if type == MSG_TYPE.MSG_TYPE_CHUNK:
                         # Replace stdout with the default stdout
                         sys.stdout = sys.__stdout__
@@ -12,6 +12,7 @@ from lollms.apps.console import MainMenu
 from lollms.paths import LollmsPaths
 from lollms.apps.console import MainMenu
 from lollms.app import LollmsApplication
+from lollms.utilities import TextVectorizer
 from typing import List, Tuple
 import importlib
 from pathlib import Path
@@ -20,6 +21,7 @@ import logging
 import yaml
 import copy
 import gc
+import json
 def reset_all_installs(lollms_paths:LollmsPaths):
     ASCIIColors.info("Removing all configuration files to force reinstall")
     ASCIIColors.info(f"Searching files from {lollms_paths.personal_configuration_path}")
@@ -322,6 +324,96 @@ class LoLLMsServer(LollmsApplication):
                     emit('personality_add_failed', {'success':False, 'error': error_message}, room=request.sid)
 
+
+        @self.socketio.on('vectorize_text')
+        def vectorize_text(parameters:dict):
+            """Vectorizes text
+
+            Args:
+                parameters (dict): contains
+                    'chunk_size': the maximum size of a text chunk (512 by default)
+                    'vectorization_method': can be either "model_embedding" or "ftidf_vectorizer" (default is "ftidf_vectorizer")
+                    'payloads': a list of dicts. Each entry has the following format:
+                        {
+                            "path": the path to the document
+                            "text": the text of the document
+                        },
+                    'return_database': if True, the vectorized database will be sent back to the client (default is True)
+                    'database_path': the path where the database is stored (default is None)
+
+            returns a dict:
+                status: True on success, False otherwise
+                if you asked for the database to be sent back, you will have these fields too:
+                    "embeddings": a dictionary containing the text chunks with their ids and embeddings
+                    "texts": a dictionary of text chunks for each embedding (use the index for correspondence)
+                    "infos": extra information
+                    "vectorizer": the vectorizer state if tfidf is used, or None if the model embedding is used
+
+            """
+            vectorization_method = parameters.get('vectorization_method',"ftidf_vectorizer")
+            chunk_size = parameters.get("chunk_size",512)
+            payloads = parameters["payloads"]
+            database_path = parameters.get("database_path",None)
+            return_database = parameters.get("return_database",True)
+            if database_path is None and not return_database:
+                ASCIIColors.warning("Vectorization should either ask to save the database or to recover it. You asked for neither!")
+                emit('vectorized_db',{"status":False, "error":"Vectorization should either ask to save the database or to recover it. You asked for neither!"})
+                return
+            tv = TextVectorizer(vectorization_method, self.model)
+            for payload in payloads:
+                tv.add_document(payload["path"], payload["text"], chunk_size=chunk_size, overlap_size=0)
+            tv.index()
+            json_db = tv.toJson()
+            if return_database:
+                emit('vectorized_db',{**{"status":True}, **json_db})
+            else:
+                emit('vectorized_db',{"status":True})
+                with open(database_path, "w") as file:
+                    json.dump(json_db, file, indent=4)
+
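For context, a hedged client-side sketch of how this event might be exercised, assuming a `python-socketio` client and a server listening on localhost:9600 (both assumptions, not part of the commit):

```python
# Hypothetical client exercising the 'vectorize_text' event.
import socketio

sio = socketio.Client()

@sio.on('vectorized_db')
def on_vectorized_db(data):
    if data.get("status"):
        # "texts" maps chunk ids to their text when return_database is True
        print(f"received {len(data.get('texts', {}))} vectorized chunks")
    else:
        print("vectorization failed:", data.get("error"))

sio.connect('http://localhost:9600')  # assumed server address
sio.emit('vectorize_text', {
    "chunk_size": 512,
    "vectorization_method": "ftidf_vectorizer",
    "payloads": [{"path": "doc1.txt", "text": "Some document text to vectorize."}],
    "return_database": True,
})
sio.wait()
```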
+
+        @self.socketio.on('query_database')
+        def query_database(parameters:dict):
+            """Queries a database
+
+            Args:
+                parameters (dict): contains
+                    'vectorization_method': can be either "model_embedding" or "ftidf_vectorizer"
+                    'database': a list of dicts. Each entry has the following format:
+                        {
+                            "embeddings": a dictionary containing the text chunks with their ids and embeddings
+                            "texts": a dictionary of text chunks for each embedding (use the index for correspondence)
+                            "infos": extra information
+                            "vectorizer": the vectorizer state if tfidf is used, or None if the model embedding is used
+                        }
+                    'database_path': if supplied, the database is loaded from this path
+                    'query': a query to search for in the database
+            """
+            vectorization_method = parameters['vectorization_method']
+            database = parameters.get("database",None)
+            query = parameters.get("query",None)
+            if query is None:
+                ASCIIColors.error("No query given!")
+                emit('vector_db_query',{"status":False, "error":"Please supply a query"})
+                return
+
+            if database is None:
+                database_path = parameters.get("database_path",None)
+                if database_path is None:
+                    ASCIIColors.error("No database given!")
+                    emit('vector_db_query',{"status":False, "error":"You did not supply a database file nor a database content"})
+                    return
+                else:
+                    with open(database_path, "r") as file:
+                        database = json.load(file)
+
+            tv = TextVectorizer(vectorization_method, self.model, database_dict=database)
+            docs, sorted_similarities = tv.recover_text(tv.embed_query(query))
+            emit('vectorized_db',{
+                "chunks":docs,
+                "refs":sorted_similarities
+            })
+
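And a matching hedged sketch for querying a previously saved database (client library, address, and file name are all assumptions):

```python
# Hypothetical client for the 'query_database' event. Note the handler
# answers on the 'vectorized_db' event name, as the diff shows.
import socketio

sio = socketio.Client()

@sio.on('vectorized_db')
def on_query_result(data):
    for chunk, ref in zip(data.get("chunks", []), data.get("refs", [])):
        print(ref, "->", chunk[:80])

sio.connect('http://localhost:9600')  # assumed server address
sio.emit('query_database', {
    "vectorization_method": "ftidf_vectorizer",
    "database_path": "my_db.json",  # assumed path written earlier by vectorize_text
    "query": "What is the document about?",
})
sio.wait()
```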
 
         @self.socketio.on('list_active_personalities')
         def handle_list_active_personalities():
             personality_names = [p.name for p in self.personalities]
@@ -394,7 +486,7 @@ class LoLLMsServer(LollmsApplication):
             if personality_id==-1:
                 # Raw text generation
                 self.answer = {"full_text":""}
-                def callback(text, message_type: MSG_TYPE):
+                def callback(text, message_type: MSG_TYPE, metadata:dict={}):
                     if message_type == MSG_TYPE.MSG_TYPE_CHUNK:
                         ASCIIColors.success(f"generated:{len(self.answer['full_text'].split())} words", end='\r')
                         self.answer["full_text"] = self.answer["full_text"] + text
@@ -467,7 +559,7 @@ class LoLLMsServer(LollmsApplication):
 
             full_discussion = personality.personality_conditioning + ''.join(full_discussion_blocks)
 
-            def callback(text, message_type: MSG_TYPE):
+            def callback(text, message_type: MSG_TYPE, metadata:dict={}):
                 if message_type == MSG_TYPE.MSG_TYPE_CHUNK:
                     self.answer["full_text"] = self.answer["full_text"] + text
                     self.socketio.emit('text_chunk', {'chunk': text}, room=client_id)
@@ -120,7 +120,7 @@ Participating personalities:
         full_discussion = ""
         return full_discussion
 
-    def safe_generate(self, full_discussion:str, n_predict=None, callback=None):
+    def safe_generate(self, full_discussion:str, n_predict=None, callback: Callable[[str, int, dict], bool]=None):
         """safe_generate
 
         Args:
@@ -195,9 +195,9 @@ class LLMBinding:
 
 
     def generate(self, 
-                 prompt:str,
+                 prompt:str,
                  n_predict: int = 128,
-                 callback: Callable[[str], None] = None,
+                 callback: Callable[[str, int, dict], bool] = None,
                  verbose: bool = False,
                  **gpt_params ):
         """Generates text out of a prompt
@@ -206,7 +206,7 @@ class LLMBinding:
         Args:
             prompt (str): The prompt to use for generation
             n_predict (int, optional): Number of tokens to predict. Defaults to 128.
-            callback (Callable[[str], None], optional): A callback function that is called every time a new text element is generated. Defaults to None.
+            callback (Callable[[str, int, dict], None], optional): A callback function that is called every time a new text element is generated. Defaults to None.
             verbose (bool, optional): If true, the code will print detailed information about the generation process. Defaults to False.
         """
         pass
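To make the new contract concrete, here is a hedged sketch of a callback matching the widened `Callable[[str, int, dict], bool]` signature. `MSG_TYPE` comes from `lollms.types` as shown in the diff; the binding instance and its construction are assumed:

```python
from lollms.types import MSG_TYPE

def my_callback(text: str, message_type: MSG_TYPE = None, metadata: dict = {}) -> bool:
    # Print streamed chunks as they arrive; ignore other message types here.
    if message_type == MSG_TYPE.MSG_TYPE_CHUNK:
        print(text, end="", flush=True)
    return True  # by convention, returning False asks the binding to stop

# binding.generate("Once upon a time", n_predict=64, callback=my_callback)
```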
@@ -17,6 +17,7 @@ import subprocess
 import yaml
 from lollms.helpers import ASCIIColors
 from lollms.types import MSG_TYPE
+from typing import Callable
 import json
 
 
@@ -54,7 +55,8 @@ class AIPersonality:
                  model:LLMBinding=None,
                  run_scripts=True,
                  is_relative_path=True,
-                 installation_option:InstallOption=InstallOption.INSTALL_IF_NECESSARY
+                 installation_option:InstallOption=InstallOption.INSTALL_IF_NECESSARY,
+                 callback: Callable[[str, int, dict], bool]=None
                  ):
         """
         Initialize an AIPersonality instance.
@@ -68,6 +70,7 @@ class AIPersonality:
         self.lollms_paths = lollms_paths
         self.model = model
         self.config = config
+        self.callback = callback
 
         self.files = []
 
@@ -248,7 +251,7 @@ Date: {{date}}
                 module = importlib.util.module_from_spec(module_spec)
                 module_spec.loader.exec_module(module)
                 if hasattr(module, "Processor"):
-                    self._processor = module.Processor(self)
+                    self._processor = module.Processor(self, callback=self.callback)
                 else:
                     self._processor = None
             else:
@@ -881,7 +884,7 @@ class StateMachine:
 
 
 
-    def process_state(self, command, full_context, callback=None):
+    def process_state(self, command, full_context, callback: Callable[[str, int, dict], bool]=None):
         """
         Process the given command based on the current state.
 
@@ -922,7 +925,8 @@ class APScript(StateMachine):
                     self,
                     personality :AIPersonality,
                     personality_config :TypedConfig,
-                    states_dict :dict = {}
+                    states_dict :dict = {},
+                    callback = None
                 ) -> None:
         super().__init__(states_dict)
         self.files=[]
@@ -932,6 +936,7 @@ class APScript(StateMachine):
         self.configuration_file_path = self.personality.lollms_paths.personal_configuration_path/f"personality_{self.personality.personality_folder_name}.yaml"
         self.personality_config.config.file_path = self.configuration_file_path
 
+        self.callback = callback
         # Installation
         if (not self.configuration_file_path.exists() or self.installation_option==InstallOption.FORCE_INSTALL) and self.installation_option!=InstallOption.NEVER_INSTALL:
             self.install()
@@ -990,8 +995,7 @@ class APScript(StateMachine):
         else:
             ASCIIColors.error("Pytorch installed successfully!!")
 
-    def add_file(self, path, callback=None):
-        self.callback=callback
+    def add_file(self, path):
         self.files.append(path)
         return True
 
@@ -1104,7 +1108,7 @@ class APScript(StateMachine):
         else:
             return False
 
-    def run_workflow(self, prompt:str, previous_discussion_text:str="", callback=None):
+    def run_workflow(self, prompt:str, previous_discussion_text:str="", callback: Callable[[str, int, dict], bool]=None):
         """
         Runs the workflow for processing the model input and output.
 
@@ -1121,7 +1125,7 @@ class APScript(StateMachine):
         """
         return None
 
-    def step_start(self, step_text, callback=None):
+    def step_start(self, step_text, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step start
 
         Args:
@@ -1131,7 +1135,7 @@ class APScript(StateMachine):
         if callback:
             callback(step_text, MSG_TYPE.MSG_TYPE_STEP_START)
 
-    def step_end(self, step_text, callback=None):
+    def step_end(self, step_text, status=True, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step end
 
         Args:
@@ -1139,9 +1143,9 @@ class APScript(StateMachine):
             callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the step end to. Defaults to None.
         """
         if callback:
-            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_END)
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_END, {'status':status})
 
-    def step(self, step_text, callback=None):
+    def step(self, step_text, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step information
 
         Args:
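A small hedged sketch of the receiving side: a callback that reads the `status` now carried in `metadata` (handler wiring assumed; `MSG_TYPE` and the metadata keys come from the diff):

```python
from lollms.types import MSG_TYPE

def on_event(text: str, msg_type: MSG_TYPE = None, metadata: dict = {}) -> bool:
    if msg_type == MSG_TYPE.MSG_TYPE_STEP_END:
        # step_end now forwards {'status': status} as the third argument
        print(f"step '{text}' ended, success={metadata.get('status', True)}")
    elif msg_type == MSG_TYPE.MSG_TYPE_STEP_PROGRESS:
        # step_progress now forwards {'progress': progress}
        print(f"step '{text}' progress: {metadata.get('progress', 0)}")
    return True
```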
@@ -1151,7 +1155,7 @@ class APScript(StateMachine):
         if callback:
             callback(step_text, MSG_TYPE.MSG_TYPE_STEP)
 
-    def exception(self, ex, callback=None):
+    def exception(self, ex, callback: Callable[[str, int, dict], bool]=None):
         """This sends an exception to the client
 
         Args:
@@ -1161,7 +1165,7 @@ class APScript(StateMachine):
         if callback:
             callback(str(ex), MSG_TYPE.MSG_TYPE_EXCEPTION)
 
-    def warning(self, warning:str, callback=None):
+    def warning(self, warning:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends a warning to the client
 
         Args:
@@ -1171,7 +1175,7 @@ class APScript(StateMachine):
         if callback:
             callback(warning, MSG_TYPE.MSG_TYPE_EXCEPTION)
 
-    def info(self, info:str, callback=None):
+    def info(self, info:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends info to the client
 
         Args:
@@ -1181,7 +1185,7 @@ class APScript(StateMachine):
         if callback:
             callback(info, MSG_TYPE.MSG_TYPE_INFO)
 
-    def json(self, json_infos:dict, callback=None):
+    def json(self, json_infos:dict, callback: Callable[[str, int, dict], bool]=None):
         """This sends json data to front end
 
         Args:
@@ -1191,7 +1195,7 @@ class APScript(StateMachine):
         if callback:
             callback(json.dumps(json_infos), MSG_TYPE.MSG_TYPE_JSON_INFOS)
 
-    def ui(self, html_ui:str, callback=None):
+    def ui(self, html_ui:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends ui elements to front end
 
         Args:
@@ -1201,7 +1205,7 @@ class APScript(StateMachine):
         if callback:
             callback(html_ui, MSG_TYPE.MSG_TYPE_UI)
 
-    def code(self, code:str, callback=None):
+    def code(self, code:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends code to front end
 
         Args:
@@ -1211,7 +1215,7 @@ class APScript(StateMachine):
         if callback:
             callback(code, MSG_TYPE.MSG_TYPE_CODE)
 
-    def full(self, full_text:str, callback=None):
+    def full(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end
 
         Args:
@@ -1224,7 +1228,7 @@ class APScript(StateMachine):
         if callback:
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL)
 
-    def full_invisible_to_ai(self, full_text:str, callback=None):
+    def full_invisible_to_ai(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end (INVISIBLE to AI)
 
         Args:
@@ -1237,7 +1241,7 @@ class APScript(StateMachine):
         if callback:
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_AI)
 
-    def full_invisible_to_user(self, full_text:str, callback=None):
+    def full_invisible_to_user(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end (INVISIBLE to user)
 
         Args:
@@ -1251,7 +1255,7 @@ class APScript(StateMachine):
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_USER)
 
 
-    def info(self, info_text:str, callback=None):
+    def info(self, info_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends info text to front end
 
         Args:
@@ -1264,7 +1268,7 @@ class APScript(StateMachine):
         if callback:
             callback(info_text, MSG_TYPE.MSG_TYPE_FULL)
 
-    def step_progress(self, progress:float, callback=None):
+    def step_progress(self, step_text:str, progress:float, callback: Callable[[str, int, dict], bool]=None):
         """This sends step progress to front end
 
         Args:
@@ -1275,8 +1279,16 @@ class APScript(StateMachine):
             callback = self.callback
 
         if callback:
-            callback(str(progress), MSG_TYPE.MSG_TYPE_STEP_PROGRESS)
-
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_PROGRESS, {'progress':progress})
+
+    # Helper method to convert outputs path to url
+    def path2url(file):
+        file = str(file).replace("\\","/")
+        pth = file.split('/')
+        idx = pth.index("outputs")
+        pth = "/".join(pth[idx:])
+        file_path = f"![](/{pth})\n"
+        return file_path
+
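As an illustration (not from the commit), `path2url` rewrites any path under an `outputs` directory into a markdown image link. Since it is defined without `self`, it can be called through the class:

```python
# Illustrative call, assuming a Windows-style path containing an "outputs" folder.
print(APScript.path2url("C:\\lollms\\personal\\outputs\\render_1.png"))
# -> ![](/outputs/render_1.png)
```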
 
 # ===========================================================
 class AIPersonalityInstaller:
@@ -1,51 +1,85 @@
 from lollms.personality import APScript
 from lollms.helpers import ASCIIColors, trace_exception
 
 from lollms.paths import LollmsPaths
+from sklearn.feature_extraction.text import TfidfVectorizer
+import numpy as np
+import json
 from pathlib import Path
 import numpy as np
 import json
 
+class TFIDFLoader:
+    @staticmethod
+    def create_vectorizer_from_dict(tfidf_info):
+        vectorizer = TfidfVectorizer(**tfidf_info['params'])
+        vectorizer.vocabulary_ = tfidf_info['vocabulary']
+        vectorizer.idf_ = [tfidf_info['idf_values'][feature] for feature in vectorizer.get_feature_names()]
+        return vectorizer
+
+    @staticmethod
+    def create_dict_from_vectorizer(vectorizer):
+        tfidf_info = {
+            "vocabulary": vectorizer.vocabulary_,
+            "idf_values": dict(zip(vectorizer.get_feature_names(), vectorizer.idf_)),
+            "params": vectorizer.get_params()
+        }
+        return tfidf_info
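A hedged round-trip sketch of the new loader (corpus and names illustrative, not from the commit):

```python
# Serialize a fitted TfidfVectorizer to a plain dict, then rebuild it.
# Note: TFIDFLoader relies on the pre-1.2 scikit-learn API (get_feature_names),
# so this assumes a matching scikit-learn version.
from sklearn.feature_extraction.text import TfidfVectorizer

vec = TfidfVectorizer()
vec.fit(["a small corpus", "another small document"])

info = TFIDFLoader.create_dict_from_vectorizer(vec)       # plain-dict state
restored = TFIDFLoader.create_vectorizer_from_dict(info)  # rebuilt vectorizer
```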
 class TextVectorizer:
-    def __init__(self, processor):
+    def __init__(
+                    self,
+                    vectorization_method, # supported "model_embedding" or "ftidf_vectorizer"
+                    model=None, #needed in case of using model_embedding
+                    database_path=None,
+                    save_db=False,
+                    visualize_data_at_startup=False,
+                    visualize_data_at_add_file=False,
+                    visualize_data_at_generate=False,
+                    data_visualization_method="PCA",
+                    database_dict=None
+                    ):
 
-        self.processor:APScript = processor
-        self.personality = self.processor.personality
-        self.model = self.personality.model
-        self.personality_config = self.processor.personality_config
-        self.lollms_paths = self.personality.lollms_paths
-        self.embeddings = {}
-        self.texts = {}
-        self.ready = False
-        self.vectorizer = None
-
-        self.database_file = Path(self.lollms_paths.personal_data_path/self.personality_config["database_path"])
+        self.vectorization_method = vectorization_method
+        self.save_db = save_db
+        self.model = model
+        self.database_file = database_path
 
-        self.visualize_data_at_startup=self.personality_config["visualize_data_at_startup"]
-        self.visualize_data_at_add_file=self.personality_config["visualize_data_at_add_file"]
-        self.visualize_data_at_generate=self.personality_config["visualize_data_at_generate"]
+        self.visualize_data_at_startup=visualize_data_at_startup
+        self.visualize_data_at_add_file=visualize_data_at_add_file
+        self.visualize_data_at_generate=visualize_data_at_generate
+        self.data_visualization_method = data_visualization_method
 
-        if self.personality_config.vectorization_method=="model_embedding":
-            try:
-                if self.model.embed("hi")==None:
-                    self.personality_config.vectorization_method="ftidf_vectorizer"
-                    self.infos={
-                        "vectorization_method":"ftidf_vectorizer"
-                    }
-                else:
-                    self.infos={
-                        "vectorization_method":"model_embedding"
-                    }
-            except Exception as ex:
-                ASCIIColors.error("Couldn't embed the text, so trying to use tfidf instead.")
-                trace_exception(ex)
-                self.infos={
-                    "vectorization_method":"ftidf_vectorizer"
-                }
+        if database_dict is not None:
+            self.chunks = []
+            self.embeddings = database_dict["embeddings"]
+            self.texts = database_dict["texts"]
+            self.infos = database_dict["infos"]
+            self.ready = True
+            self.vectorizer = database_dict["vectorizer"]
+        else:
+            self.chunks = []
+            self.embeddings = {}
+            self.texts = {}
+            self.ready = False
+            self.vectorizer = None
+
+            if vectorization_method=="model_embedding":
+                try:
+                    if not self.model or self.model.embed("hi")==None: # test
+                        self.vectorization_method="ftidf_vectorizer"
+                        self.infos={
+                            "vectorization_method":"ftidf_vectorizer"
+                        }
+                    else:
+                        self.infos={
+                            "vectorization_method":"model_embedding"
+                        }
+                except Exception as ex:
+                    ASCIIColors.error("Couldn't embed the text, so trying to use tfidf instead.")
+                    trace_exception(ex)
+                    self.infos={
+                        "vectorization_method":"ftidf_vectorizer"
+                    }
         # Load previous state from the JSON file
-        if self.personality_config.save_db:
+        if self.save_db:
             if Path(self.database_file).exists():
                 ASCIIColors.success(f"Database file found : {self.database_file}")
                 self.load_from_json()
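A hedged construction sketch of the decoupled class (argument values illustrative; not from the commit):

```python
# The vectorizer no longer needs an APScript processor. Note that add_document
# tokenizes through the model, so a binding exposing tokenize/detokenize (and
# embed for "model_embedding") is assumed to exist as `binding`.
from lollms.utilities import TextVectorizer

tv = TextVectorizer(
    "ftidf_vectorizer",          # or "model_embedding"
    model=binding,               # assumed LLMBinding instance
    database_path="my_db.json",  # illustrative path
    save_db=True,
)
tv.add_document("doc1", "Some text. More text.", chunk_size=512, overlap_size=0)
tv.index()
docs, refs = tv.recover_text(tv.embed_query("some query"))
```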
@@ -56,7 +90,7 @@ class TextVectorizer:
                 ASCIIColors.info(f"No database file found : {self.database_file}")
 
 
-    def show_document(self, query_text=None):
+    def show_document(self, query_text=None, save_fig_path=None, show_interactive_form=False):
         import textwrap
         import seaborn as sns
         import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.manifold import TSNE
|
||||
from sklearn.decomposition import PCA
|
||||
import torch
|
||||
|
||||
if self.personality_config.data_visualization_method=="PCA":
|
||||
if self.data_visualization_method=="PCA":
|
||||
use_pca = True
|
||||
else:
|
||||
use_pca = False
|
||||
@ -80,6 +113,7 @@ class TextVectorizer:
|
||||
texts = list(self.texts.values())
|
||||
embeddings = self.embeddings
|
||||
emb = list(embeddings.values())
|
||||
ref = list(embeddings.keys())
|
||||
if len(emb)>=2:
|
||||
# Normalize embeddings
|
||||
emb = np.vstack(emb)
|
||||
@ -94,6 +128,7 @@ class TextVectorizer:
|
||||
|
||||
# Combine the query embedding with the document embeddings
|
||||
combined_embeddings = np.vstack((normalized_embeddings, query_normalized_embedding))
|
||||
ref.append("Quey_chunk_0")
|
||||
else:
|
||||
# Combine the query embedding with the document embeddings
|
||||
combined_embeddings = normalized_embeddings
|
||||
@ -113,13 +148,19 @@ class TextVectorizer:
|
||||
tsne = TSNE(n_components=2, perplexity=perplexity)
|
||||
embeddings_2d = tsne.fit_transform(combined_embeddings)
|
||||
|
||||
# Create a dictionary to map document paths to colors
|
||||
document_path_colors = {}
|
||||
for i, path in enumerate(ref):
|
||||
document_path = "_".join(path.split("_")[:-1]) # Extract the document path (excluding chunk and chunk number)
|
||||
if document_path not in document_path_colors:
|
||||
# Assign a new color to the document path if it's not in the dictionary
|
||||
document_path_colors[document_path] = sns.color_palette("hls", len(document_path_colors) + 1)[-1]
|
||||
|
||||
# Generate a list of colors for each data point based on the document path
|
||||
point_colors = [document_path_colors["_".join(path.split("_")[:-1])] for path in ref]
|
||||
|
||||
# Create a scatter plot using Seaborn
|
||||
if query_text is not None:
|
||||
sns.scatterplot(x=embeddings_2d[:-1, 0], y=embeddings_2d[:-1, 1]) # Plot document embeddings
|
||||
plt.scatter(embeddings_2d[-1, 0], embeddings_2d[-1, 1], color='red') # Plot query embedding
|
||||
else:
|
||||
sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1]) # Plot document embeddings
|
||||
sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1], hue=point_colors) # Plot document embeddings
|
||||
# Add labels to the scatter plot
|
||||
for i, (x, y) in enumerate(embeddings_2d[:-1]):
|
||||
plt.text(x, y, str(i), fontsize=8)
|
||||
@ -176,11 +217,12 @@ class TextVectorizer:
|
||||
|
||||
# Connect the click event handler to the figure
|
||||
plt.gcf().canvas.mpl_connect("button_press_event", on_click)
|
||||
plt.savefig(self.lollms_paths.personal_uploads_path / self.personality.personality_folder_name/ "db.png")
|
||||
plt.show()
|
||||
if save_fig_path:
|
||||
plt.savefig(save_fig_path)
|
||||
if show_interactive_form:
|
||||
plt.show()
|
||||
|
||||
def index_document(self, document_id, text, chunk_size, overlap_size, force_vectorize=False):
|
||||
|
||||
def add_document(self, document_id, text, chunk_size, overlap_size, force_vectorize=False):
|
||||
if document_id in self.embeddings and not force_vectorize:
|
||||
print(f"Document {document_id} already exists. Skipping vectorization.")
|
||||
return
|
||||
@ -188,15 +230,13 @@ class TextVectorizer:
|
||||
# Split tokens into sentences
|
||||
sentences = text.split('. ')
|
||||
def remove_empty_sentences(sentences):
|
||||
return [sentence for sentence in sentences if sentence.strip() != '']
|
||||
return [self.model.tokenize(sentence) for sentence in sentences if sentence.strip() != '']
|
||||
sentences = remove_empty_sentences(sentences)
|
||||
# Generate chunks with overlap and sentence boundaries
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
for i in range(len(sentences)):
|
||||
sentence = sentences[i]
|
||||
sentence_tokens = self.model.tokenize(sentence)
|
||||
|
||||
sentence_tokens = sentences[i]
|
||||
|
||||
# ASCIIColors.yellow(len(sentence_tokens))
|
||||
if len(current_chunk) + len(sentence_tokens) <= chunk_size:
|
||||
@ -204,45 +244,51 @@ class TextVectorizer:
|
||||
else:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
while len(sentence_tokens)>chunk_size:
|
||||
current_chunk = sentence_tokens[0:chunk_size]
|
||||
sentence_tokens = sentence_tokens[chunk_size:]
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = sentence_tokens
|
||||
|
||||
|
||||
current_chunk=[]
|
||||
for j in reversed(range(overlap_size)):
|
||||
current_chunk.extend(sentences[i-j-1])
|
||||
current_chunk.extend(sentence_tokens)
|
||||
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk)
|
||||
|
||||
if self.personality_config.vectorization_method=="ftidf_vectorizer":
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
chunk_id = f"{document_id}_chunk_{i + 1}"
|
||||
chunk_dict = {
|
||||
"chunk_id": chunk_id,
|
||||
"chunk_text": chunk_text
|
||||
}
|
||||
self.chunks.append(chunk_dict)
|
||||
|
||||
def index(self):
|
||||
if self.vectorization_method=="ftidf_vectorizer":
|
||||
self.vectorizer = TfidfVectorizer()
|
||||
#if self.personality.config.debug:
|
||||
#if self.debug:
|
||||
# ASCIIColors.yellow(','.join([len(chunk) for chunk in chunks]))
|
||||
data=[]
|
||||
for chunk in chunks:
|
||||
for chunk in self.chunks:
|
||||
try:
|
||||
data.append(self.model.detokenize(chunk).replace("<s>","").replace("</s>","") )
|
||||
data.append(self.model.detokenize(chunk["chunk_text"]).replace("<s>","").replace("</s>","") )
|
||||
except Exception as ex:
|
||||
print("oups")
|
||||
self.vectorizer.fit(data)
|
||||
|
||||
self.embeddings = {}
|
||||
# Generate embeddings for each chunk
|
||||
for i, chunk in enumerate(chunks):
|
||||
for i, chunk in enumerate(self.chunks):
|
||||
# Store chunk ID, embedding, and original text
|
||||
chunk_id = f"{document_id}_chunk_{i + 1}"
|
||||
chunk_id = chunk["chunk_id"]
|
||||
chunk_text = chunk["chunk_text"]
|
||||
try:
|
||||
self.texts[chunk_id] = self.model.detokenize(chunk[:chunk_size])
|
||||
if self.personality_config.vectorization_method=="ftidf_vectorizer":
|
||||
self.texts[chunk_id] = self.model.detokenize(chunk_text)
|
||||
if self.vectorization_method=="ftidf_vectorizer":
|
||||
self.embeddings[chunk_id] = self.vectorizer.transform([self.texts[chunk_id]]).toarray()
|
||||
else:
|
||||
self.embeddings[chunk_id] = self.model.embed(self.texts[chunk_id])
|
||||
except Exception as ex:
|
||||
print("oups")
|
||||
|
||||
if self.personality_config.save_db:
|
||||
if self.save_db:
|
||||
self.save_to_json()
|
||||
|
||||
self.ready = True
|
||||
@ -252,10 +298,14 @@ class TextVectorizer:
|
||||
|
||||
def embed_query(self, query_text):
|
||||
# Generate query embedding
|
||||
if self.personality_config.vectorization_method=="ftidf_vectorizer":
|
||||
if self.vectorization_method=="ftidf_vectorizer":
|
||||
query_embedding = self.vectorizer.transform([query_text]).toarray()
|
||||
else:
|
||||
query_embedding = self.model.embed(query_text)
|
||||
if query_embedding is None:
|
||||
ASCIIColors.warning("The model doesn't implement embedding extraction")
|
||||
self.vectorization_method="ftidf_vectorizer"
|
||||
query_embedding = self.vectorizer.transform([query_text]).toarray()
|
||||
|
||||
return query_embedding
|
||||
|
||||
@ -277,6 +327,18 @@ class TextVectorizer:
|
||||
|
||||
return texts, sorted_similarities
|
||||
|
||||
def toJson(self):
|
||||
state = {
|
||||
"embeddings": {str(k): v.tolist() if type(v)!=list else v for k, v in self.embeddings.items() },
|
||||
"texts": self.texts,
|
||||
"infos": self.infos,
|
||||
"vectorizer": TFIDFLoader.create_vectorizer_from_dict(self.vectorizer) if self.vectorization_method=="ftidf_vectorizer" else None
|
||||
}
|
||||
return state
|
||||
|
||||
def setVectorizer(self, vectorizer_dict:dict):
|
||||
self.vectorizer=TFIDFLoader.create_vectorizer_from_dict(vectorizer_dict)
|
||||
|
||||
def save_to_json(self):
|
||||
state = {
|
||||
"embeddings": {str(k): v.tolist() if type(v)!=list else v for k, v in self.embeddings.items() },
|
||||
@ -295,7 +357,7 @@ class TextVectorizer:
|
||||
self.texts = state["texts"]
|
||||
self.infos= state["infos"]
|
||||
self.ready = True
|
||||
if self.personality_config.vectorization_method=="ftidf_vectorizer":
|
||||
if self.vectorization_method=="ftidf_vectorizer":
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
data = list(self.texts.values())
|
||||
if len(data)>0:
|
||||
@ -304,11 +366,13 @@ class TextVectorizer:
|
||||
self.embeddings={}
|
||||
for k,v in self.texts.items():
|
||||
self.embeddings[k]= self.vectorizer.transform([v]).toarray()
|
||||
|
||||
|
||||
def clear_database(self):
|
||||
self.vectorizer=None
|
||||
self.embeddings = {}
|
||||
self.texts={}
|
||||
if self.personality_config.save_db:
|
||||
if self.save_db:
|
||||
self.save_to_json()
|
||||
|
||||
|
||||
|
@@ -8,4 +8,9 @@ simple-websocket
 eventlet
 wget
 setuptools
-requests
+requests
+
+matplotlib
+seaborn
+mplcursors
+scikit-learn