diff --git a/examples/chat_forever/console.py b/examples/chat_forever/console.py
index 0f72871..44314da 100644
--- a/examples/chat_forever/console.py
+++ b/examples/chat_forever/console.py
@@ -14,7 +14,7 @@ class MyConversation(Conversation):
         self.menu.main_menu()
         full_discussion += self.personality.user_message_prefix+prompt+self.personality.link_text
         full_discussion += self.personality.ai_message_prefix
-        def callback(text, type=None):
+        def callback(text, type=None, metadata:dict={}):
             print(text, end="")
             sys.stdout = sys.__stdout__
             sys.stdout.flush()
diff --git a/examples/simple_story/console.py b/examples/simple_story/console.py
index b6f6fd2..a29a028 100644
--- a/examples/simple_story/console.py
+++ b/examples/simple_story/console.py
@@ -7,7 +7,7 @@ class MyConversation(Conversation):
     def start_conversation(self):
         prompt = "Once apon a time"
-        def callback(text, type=None):
+        def callback(text, type=None, metadata:dict={}):
             print(text, end="")
             sys.stdout = sys.__stdout__
             sys.stdout.flush()
diff --git a/lollms/apps/console/__init__.py b/lollms/apps/console/__init__.py
index f685e0e..a5804cc 100644
--- a/lollms/apps/console/__init__.py
+++ b/lollms/apps/console/__init__.py
@@ -7,6 +7,7 @@ from lollms.paths import LollmsPaths
 from lollms.app import LollmsApplication
 from lollms.terminal import MainMenu

+from typing import Callable
 from pathlib import Path
 import argparse
 import yaml
@@ -133,7 +134,7 @@ Participating personalities:
             full_discussion = ""
         return full_discussion

-    def safe_generate(self, full_discussion:str, n_predict=None, callback=None):
+    def safe_generate(self, full_discussion:str, n_predict=None, callback: Callable[[str, int, dict], bool]=None):
         """safe_generate

         Args:
@@ -236,7 +237,7 @@ Participating personalities:
                 self.personality.ai_message_prefix
             )

-            def callback(text, type:MSG_TYPE=None):
+            def callback(text, type:MSG_TYPE=None, metadata:dict={}):
                 if type == MSG_TYPE.MSG_TYPE_CHUNK:
                     # Replace stdout with the default stdout
                     sys.stdout = sys.__stdout__
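The recurring change across the examples and apps is the callback signature: callbacks now take an optional `metadata` dict as a third argument, and the annotations declare a bool return. A minimal sketch of the updated contract, using `MSG_TYPE` from `lollms.types`; the stop-on-False behavior reflects the new `Callable[[str, int, dict], bool]` annotation and is illustrative, not part of this diff:

```python
import sys
from lollms.types import MSG_TYPE

def callback(text: str, type: MSG_TYPE = None, metadata: dict = {}) -> bool:
    if type == MSG_TYPE.MSG_TYPE_CHUNK:
        # streamed text chunks are printed as they arrive
        print(text, end="")
        sys.stdout.flush()
    elif type == MSG_TYPE.MSG_TYPE_STEP_END:
        # step_end (see lollms/personality.py below) now forwards a status flag
        print(f"\nstep done, success={metadata.get('status', True)}")
    return True  # returning False asks the generator to stop
```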
diff --git a/lollms/apps/server/__init__.py b/lollms/apps/server/__init__.py
index ee1fa3d..f406971 100644
--- a/lollms/apps/server/__init__.py
+++ b/lollms/apps/server/__init__.py
@@ -12,6 +12,7 @@ from lollms.apps.console import MainMenu
 from lollms.paths import LollmsPaths
 from lollms.apps.console import MainMenu
 from lollms.app import LollmsApplication
+from lollms.utilities import TextVectorizer
 from typing import List, Tuple
 import importlib
 from pathlib import Path
@@ -20,6 +21,7 @@ import logging
 import yaml
 import copy
 import gc
+import json

 def reset_all_installs(lollms_paths:LollmsPaths):
     ASCIIColors.info("Removeing all configuration files to force reinstall")
     ASCIIColors.info(f"Searching files from {lollms_paths.personal_configuration_path}")
@@ -322,6 +324,96 @@ class LoLLMsServer(LollmsApplication):
                     emit('personality_add_failed', {'success':False, 'error': error_message}, room=request.sid)

+
+        @self.socketio.on('vectorize_text')
+        def vectorize_text(parameters:dict):
+            """Vectorizes text
+
+            Args:
+                parameters (dict): contains
+                    'chunk_size': the maximum size of a text chunk (512 by default)
+                    'vectorization_method': either "model_embedding" or "ftidf_vectorizer" (default is "ftidf_vectorizer")
+                    'payloads': a list of dicts. Each entry has the following format:
+                        {
+                            "path": the path to the document,
+                            "text": the text of the document
+                        }
+                    'return_database': if True, the vectorized database is sent back to the client (default is True)
+                    'database_path': the path where the database should be stored (default is None)
+
+            Returns a dict:
+                "status": True on success, False on failure
+                If you asked for the database to be sent back, you also get these fields:
+                    "embeddings": a dictionary containing the text chunks with their ids and embeddings
+                    "texts": a dictionary of text chunks for each embedding (use the index for correspondence)
+                    "infos": extra information
+                    "vectorizer": the vectorizer state when tfidf is used, or None when the model embedding is used
+            """
+            vectorization_method = parameters.get('vectorization_method',"ftidf_vectorizer")
+            chunk_size = parameters.get("chunk_size",512)
+            payloads = parameters["payloads"]
+            database_path = parameters.get("database_path",None)
+            return_database = parameters.get("return_database",True)
+            if database_path is None and not return_database:
+                ASCIIColors.warning("Vectorization should either ask to save the database or to recover it. You asked for neither!")
+                emit('vectorized_db',{"status":False, "error":"Vectorization should either ask to save the database or to recover it. You asked for neither!"})
+                return
+            tv = TextVectorizer(vectorization_method, self.model)
+            for payload in payloads:
+                tv.add_document(payload["path"], payload["text"], chunk_size=chunk_size, overlap_size=0)
+            tv.index()
+            json_db = tv.toJson()
+            if return_database:
+                emit('vectorized_db',{**{"status":True}, **json_db})
+            else:
+                emit('vectorized_db',{"status":True})
+                with open(database_path, "w") as file:
+                    json.dump(json_db, file, indent=4)
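A hypothetical client for the new `vectorize_text` endpoint, built from the docstring above with the `python-socketio` client. The server URL/port and the document content are placeholders, not taken from this diff:

```python
import socketio

sio = socketio.Client()

@sio.on('vectorized_db')
def on_vectorized_db(data):
    if data["status"]:
        print(f"received {len(data.get('texts', {}))} vectorized chunks")
    else:
        print("vectorization failed:", data["error"])

sio.connect('http://localhost:9600')  # assumed host/port
sio.emit('vectorize_text', {
    "chunk_size": 512,
    "vectorization_method": "ftidf_vectorizer",
    "payloads": [{"path": "docs/sample.txt", "text": "Some text. More text."}],
    "return_database": True
})
sio.wait()
```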
+
+        @self.socketio.on('query_database')
+        def query_database(parameters:dict):
+            """Queries a database
+
+            Args:
+                parameters (dict): contains
+                    'vectorization_method': either "model_embedding" or "ftidf_vectorizer"
+                    'database': a list of dicts. Each entry has the following format:
+                        {
+                            "embeddings": a dictionary containing the text chunks with their ids and embeddings
+                            "texts": a dictionary of text chunks for each embedding (use the index for correspondence)
+                            "infos": extra information
+                            "vectorizer": the vectorizer state when tfidf is used, or None when the model embedding is used
+                        }
+                    'database_path': if supplied, the database is loaded from this path
+                    'query': a query to search for in the database
+            """
+            vectorization_method = parameters['vectorization_method']
+            database = parameters.get("database",None)
+            query = parameters.get("query",None)
+            if query is None:
+                ASCIIColors.error("No query given!")
+                emit('vector_db_query',{"status":False, "error":"Please supply a query"})
+                return
+
+            if database is None:
+                database_path = parameters.get("database_path",None)
+                if database_path is None:
+                    ASCIIColors.error("No database given!")
+                    emit('vector_db_query',{"status":False, "error":"You did not supply a database file nor a database content"})
+                    return
+                else:
+                    with open(database_path, "r") as file:
+                        database = json.load(file)
+
+            tv = TextVectorizer(vectorization_method, self.model, database_dict=database)
+            docs, sorted_similarities = tv.recover_text(tv.embed_query(query))
+            emit('vectorized_db',{
+                "chunks":docs,
+                "refs":sorted_similarities
+            })
+
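The matching query side, again as an illustrative sketch. Note that the handler above emits results on `vectorized_db` and only uses `vector_db_query` for errors; the file name and query are placeholders:

```python
import socketio

sio = socketio.Client()

@sio.on('vectorized_db')        # query results arrive on this event
def on_results(data):
    for chunk in data["chunks"]:
        print(chunk)

@sio.on('vector_db_query')      # only errors are emitted on this event
def on_error(data):
    print("query failed:", data["error"])

sio.connect('http://localhost:9600')  # assumed host/port
sio.emit('query_database', {
    "vectorization_method": "ftidf_vectorizer",
    "database_path": "my_db.json",    # placeholder; or pass "database" inline
    "query": "how do I install it?"
})
sio.wait()
```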
         @self.socketio.on('list_active_personalities')
         def handle_list_active_personalities():
             personality_names = [p.name for p in self.personalities]
@@ -394,7 +486,7 @@ class LoLLMsServer(LollmsApplication):
             if personality_id==-1:
                 # Raw text generation
                 self.answer = {"full_text":""}
-                def callback(text, message_type: MSG_TYPE):
+                def callback(text, message_type: MSG_TYPE, metadata:dict={}):
                     if message_type == MSG_TYPE.MSG_TYPE_CHUNK:
                         ASCIIColors.success(f"generated:{len(self.answer['full_text'].split())} words", end='\r')
                         self.answer["full_text"] = self.answer["full_text"] + text
@@ -467,7 +559,7 @@ class LoLLMsServer(LollmsApplication):

                     full_discussion = personality.personality_conditioning + ''.join(full_discussion_blocks)

-                    def callback(text, message_type: MSG_TYPE):
+                    def callback(text, message_type: MSG_TYPE, metadata:dict={}):
                         if message_type == MSG_TYPE.MSG_TYPE_CHUNK:
                             self.answer["full_text"] = self.answer["full_text"] + text
                             self.socketio.emit('text_chunk', {'chunk': text}, room=client_id)
diff --git a/lollms/apps/settings/__init__.py b/lollms/apps/settings/__init__.py
index 4285c79..338418f 100644
--- a/lollms/apps/settings/__init__.py
+++ b/lollms/apps/settings/__init__.py
@@ -120,7 +120,7 @@ Participating personalities:
             full_discussion = ""
         return full_discussion

-    def safe_generate(self, full_discussion:str, n_predict=None, callback=None):
+    def safe_generate(self, full_discussion:str, n_predict=None, callback: Callable[[str, int, dict], bool]=None):
         """safe_generate

         Args:
diff --git a/lollms/binding.py b/lollms/binding.py
index 3a7c1ac..5f9d30b 100644
--- a/lollms/binding.py
+++ b/lollms/binding.py
@@ -195,9 +195,9 @@ class LLMBinding:

     def generate(self,
-                 prompt:str,
+                 prompt:str,
                  n_predict: int = 128,
-                 callback: Callable[[str], None] = None,
+                 callback: Callable[[str, int, dict], bool] = None,
                  verbose: bool = False,
                  **gpt_params ):
         """Generates text out of a prompt
@@ -206,7 +206,7 @@ class LLMBinding:
         Args:
             prompt (str): The prompt to use for generation
             n_predict (int, optional): Number of tokens to prodict. Defaults to 128.
-            callback (Callable[[str], None], optional): A callback function that is called everytime a new text element is generated. Defaults to None.
+            callback (Callable[[str, int, dict], None], optional): A callback function that is called every time a new text element is generated. Defaults to None.
             verbose (bool, optional): If true, the code will spit many informations about the generation process. Defaults to False.
         """
         pass
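For binding authors, the new annotation implies a contract like the following sketch: stream each chunk through the callback and stop as soon as it returns False. `_stream_tokens` is a hypothetical stand-in for a concrete binding's decoding loop, not a lollms API:

```python
from lollms.binding import LLMBinding
from lollms.types import MSG_TYPE

class MyBinding(LLMBinding):
    def generate(self, prompt: str, n_predict: int = 128, callback=None, verbose: bool = False, **gpt_params):
        output = ""
        for token in self._stream_tokens(prompt, n_predict):  # hypothetical decoding loop
            output += token
            # third argument carries per-chunk metadata (empty here)
            if callback is not None and not callback(token, MSG_TYPE.MSG_TYPE_CHUNK, {}):
                break  # the callback asked us to stop generating
        return output
```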
diff --git a/lollms/personality.py b/lollms/personality.py
index 1559058..9fd7abe 100644
--- a/lollms/personality.py
+++ b/lollms/personality.py
@@ -17,6 +17,7 @@ import subprocess
 import yaml
 from lollms.helpers import ASCIIColors
 from lollms.types import MSG_TYPE
+from typing import Callable
 import json

@@ -54,7 +55,8 @@ class AIPersonality:
                  model:LLMBinding=None,
                  run_scripts=True,
                  is_relative_path=True,
-                 installation_option:InstallOption=InstallOption.INSTALL_IF_NECESSARY
+                 installation_option:InstallOption=InstallOption.INSTALL_IF_NECESSARY,
+                 callback: Callable[[str, int, dict], bool]=None
                  ):
         """
         Initialize an AIPersonality instance.
@@ -68,6 +70,7 @@ class AIPersonality:
         self.lollms_paths = lollms_paths
         self.model = model
         self.config = config
+        self.callback = callback

         self.files = []

@@ -248,7 +251,7 @@ Date: {{date}}
                 module = importlib.util.module_from_spec(module_spec)
                 module_spec.loader.exec_module(module)
                 if hasattr(module, "Processor"):
-                    self._processor = module.Processor(self)
+                    self._processor = module.Processor(self, callback=self.callback)
                 else:
                     self._processor = None
             else:
@@ -881,7 +884,7 @@ class StateMachine:

-    def process_state(self, command, full_context, callback=None):
+    def process_state(self, command, full_context, callback: Callable[[str, int, dict], bool]=None):
         """
         Process the given command based on the current state.

@@ -922,7 +925,8 @@ class APScript(StateMachine):
                     self,
                     personality        :AIPersonality,
                     personality_config :TypedConfig,
-                    states_dict        :dict   = {}
+                    states_dict        :dict   = {},
+                    callback           = None
                 ) -> None:
         super().__init__(states_dict)
         self.files=[]
@@ -932,6 +936,7 @@ class APScript(StateMachine):

         self.configuration_file_path = self.personality.lollms_paths.personal_configuration_path/f"personality_{self.personality.personality_folder_name}.yaml"
         self.personality_config.config.file_path = self.configuration_file_path
+        self.callback = callback
         # Installation
         if (not self.configuration_file_path.exists() or self.installation_option==InstallOption.FORCE_INSTALL) and self.installation_option!=InstallOption.NEVER_INSTALL:
             self.install()
@@ -990,8 +995,7 @@ class APScript(StateMachine):
         else:
             ASCIIColors.error("Pytorch installed successfully!!")

-    def add_file(self, path, callback=None):
-        self.callback=callback
+    def add_file(self, path):
         self.files.append(path)
         return True

@@ -1104,7 +1108,7 @@ class APScript(StateMachine):
         else:
             return False

-    def run_workflow(self, prompt:str, previous_discussion_text:str="", callback=None):
+    def run_workflow(self, prompt:str, previous_discussion_text:str="", callback: Callable[[str, int, dict], bool]=None):
         """
         Runs the workflow for processing the model input and output.

@@ -1121,7 +1125,7 @@ class APScript(StateMachine):
         """
         return None

-    def step_start(self, step_text, callback=None):
+    def step_start(self, step_text, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step start

         Args:
@@ -1131,7 +1135,7 @@ class APScript(StateMachine):
         if callback:
             callback(step_text, MSG_TYPE.MSG_TYPE_STEP_START)

-    def step_end(self, step_text, callback=None):
+    def step_end(self, step_text, status=True, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step end

         Args:
@@ -1139,9 +1143,9 @@ class APScript(StateMachine):
             callback (callable, optional): A callable with this signature (str, MSG_TYPE) to send the step end to. Defaults to None.
         """
         if callback:
-            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_END)
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_END, {'status':status})

-    def step(self, step_text, callback=None):
+    def step(self, step_text, callback: Callable[[str, int, dict], bool]=None):
         """This triggers a step information

         Args:
@@ -1151,7 +1155,7 @@ class APScript(StateMachine):
         if callback:
             callback(step_text, MSG_TYPE.MSG_TYPE_STEP)

-    def exception(self, ex, callback=None):
+    def exception(self, ex, callback: Callable[[str, int, dict], bool]=None):
         """This sends exception to the client

         Args:
@@ -1161,7 +1165,7 @@ class APScript(StateMachine):
         if callback:
             callback(str(ex), MSG_TYPE.MSG_TYPE_EXCEPTION)

-    def warning(self, warning:str, callback=None):
+    def warning(self, warning:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends exception to the client

         Args:
@@ -1171,7 +1175,7 @@ class APScript(StateMachine):
         if callback:
             callback(warning, MSG_TYPE.MSG_TYPE_EXCEPTION)

-    def info(self, info:str, callback=None):
+    def info(self, info:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends exception to the client

         Args:
@@ -1181,7 +1185,7 @@ class APScript(StateMachine):
         if callback:
             callback(info, MSG_TYPE.MSG_TYPE_INFO)

-    def json(self, json_infos:dict, callback=None):
+    def json(self, json_infos:dict, callback: Callable[[str, int, dict], bool]=None):
         """This sends json data to front end

         Args:
@@ -1191,7 +1195,7 @@ class APScript(StateMachine):
         if callback:
             callback(json.dumps(json_infos), MSG_TYPE.MSG_TYPE_JSON_INFOS)

-    def ui(self, html_ui:str, callback=None):
+    def ui(self, html_ui:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends ui elements to front end

         Args:
@@ -1201,7 +1205,7 @@ class APScript(StateMachine):
         if callback:
             callback(html_ui, MSG_TYPE.MSG_TYPE_UI)

-    def code(self, code:str, callback=None):
+    def code(self, code:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends code to front end

         Args:
@@ -1211,7 +1215,7 @@ class APScript(StateMachine):
         if callback:
             callback(code, MSG_TYPE.MSG_TYPE_CODE)

-    def full(self, full_text:str, callback=None):
+    def full(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end

         Args:
@@ -1224,7 +1228,7 @@ class APScript(StateMachine):
         if callback:
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL)

-    def full_invisible_to_ai(self, full_text:str, callback=None):
+    def full_invisible_to_ai(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end (INVISIBLE to AI)

         Args:
@@ -1237,7 +1241,7 @@ class APScript(StateMachine):
         if callback:
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_AI)

-    def full_invisible_to_user(self, full_text:str, callback=None):
+    def full_invisible_to_user(self, full_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends full text to front end (INVISIBLE to user)

         Args:
@@ -1251,7 +1255,7 @@ class APScript(StateMachine):
             callback(full_text, MSG_TYPE.MSG_TYPE_FULL_INVISIBLE_TO_USER)

-    def info(self, info_text:str, callback=None):
+    def info(self, info_text:str, callback: Callable[[str, int, dict], bool]=None):
         """This sends info text to front end

         Args:
@@ -1264,7 +1268,7 @@ class APScript(StateMachine):
         if callback:
             callback(info_text, MSG_TYPE.MSG_TYPE_FULL)

-    def step_progress(self, progress:float, callback=None):
+    def step_progress(self, step_text:str, progress:float, callback: Callable[[str, int, dict], bool]=None):
         """This sends step rogress to front end

         Args:
@@ -1275,8 +1279,16 @@ class APScript(StateMachine):
             callback = self.callback

         if callback:
-            callback(str(progress), MSG_TYPE.MSG_TYPE_STEP_PROGRESS)
-
+            callback(step_text, MSG_TYPE.MSG_TYPE_STEP_PROGRESS, {'progress':progress})
+
+    # Helper method to convert an outputs path to a url
+    @staticmethod
+    def path2url(file):
+        file = str(file).replace("\\","/")
+        pth = file.split('/')
+        idx = pth.index("outputs")
+        pth = "/".join(pth[idx:])
+        file_path = f"![](/{pth})\n"
+        return file_path

 # ===========================================================
 class AIPersonalityInstaller:
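Taken together, the personality-side changes let a processor report structured progress: the callback now arrives through the constructor (rather than `add_file`), `step_end` carries a success flag, and `step_progress` takes the step text plus a numeric progress value. A hypothetical `APScript` workflow using the new helpers (instantiation and config wiring omitted):

```python
from lollms.personality import APScript

class Processor(APScript):
    def run_workflow(self, prompt: str, previous_discussion_text: str = "", callback=None):
        self.step_start("Processing the prompt", callback)
        for i in range(5):
            # delivered to the callback as metadata {'progress': ...}
            self.step_progress("Processing the prompt", (i + 1) * 20.0, callback)
        self.step_end("Processing the prompt", status=True, callback=callback)
        output = "workflow output"
        self.full(output, callback)
        return output
```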
diff --git a/lollms/utilities.py b/lollms/utilities.py
index 6b02642..04858e4 100644
--- a/lollms/utilities.py
+++ b/lollms/utilities.py
@@ -1,51 +1,85 @@
 from lollms.personality import APScript
 from lollms.helpers import ASCIIColors, trace_exception
-
+from lollms.paths import LollmsPaths
+from sklearn.feature_extraction.text import TfidfVectorizer
 import numpy as np
-import json
 from pathlib import Path
-import numpy as np
 import json

+class TFIDFLoader:
+    @staticmethod
+    def create_vectorizer_from_dict(tfidf_info):
+        vectorizer = TfidfVectorizer(**tfidf_info['params'])
+        vectorizer.vocabulary_ = tfidf_info['vocabulary']
+        vectorizer.idf_ = [tfidf_info['idf_values'][feature] for feature in vectorizer.get_feature_names()]
+        return vectorizer
+
+    @staticmethod
+    def create_dict_from_vectorizer(vectorizer):
+        tfidf_info = {
+            "vocabulary": vectorizer.vocabulary_,
+            "idf_values": dict(zip(vectorizer.get_feature_names(), vectorizer.idf_)),
+            "params": vectorizer.get_params()
+        }
+        return tfidf_info
+
 class TextVectorizer:
-    def __init__(self, processor):
+    def __init__(
+                    self,
+                    vectorization_method, # supported: "model_embedding" or "ftidf_vectorizer"
+                    model=None, # needed in case of using model_embedding
+                    database_path=None,
+                    save_db=False,
+                    visualize_data_at_startup=False,
+                    visualize_data_at_add_file=False,
+                    visualize_data_at_generate=False,
+                    data_visualization_method="PCA",
+                    database_dict=None
+                    ):

-        self.processor:APScript = processor
-        self.personality = self.processor.personality
-        self.model = self.personality.model
-        self.personality_config = self.processor.personality_config
-        self.lollms_paths = self.personality.lollms_paths
-        self.embeddings = {}
-        self.texts = {}
-        self.ready = False
-        self.vectorizer = None
-
-        self.database_file = Path(self.lollms_paths.personal_data_path/self.personality_config["database_path"])
+        self.vectorization_method = vectorization_method
+        self.save_db = save_db
+        self.model = model
+        self.database_file = database_path

-        self.visualize_data_at_startup=self.personality_config["visualize_data_at_startup"]
-        self.visualize_data_at_add_file=self.personality_config["visualize_data_at_add_file"]
-        self.visualize_data_at_generate=self.personality_config["visualize_data_at_generate"]
+        self.visualize_data_at_startup=visualize_data_at_startup
+        self.visualize_data_at_add_file=visualize_data_at_add_file
+        self.visualize_data_at_generate=visualize_data_at_generate

-        if self.personality_config.vectorization_method=="model_embedding":
-            try:
-                if self.model.embed("hi")==None:
-                    self.personality_config.vectorization_method="ftidf_vectorizer"
+        self.data_visualization_method = data_visualization_method
+
+        if database_dict is not None:
+            self.chunks = []
+            self.embeddings = database_dict["embeddings"]
+            self.texts = database_dict["texts"]
+            self.infos = database_dict["infos"]
+            self.ready = True
+            self.vectorizer = database_dict["vectorizer"]
+        else:
+            self.chunks = []
+            self.embeddings = {}
+            self.texts = {}
+            self.ready = False
+            self.vectorizer = None
+
+            if vectorization_method=="model_embedding":
+                try:
+                    if not self.model or self.model.embed("hi")==None: # test
+                        self.vectorization_method="ftidf_vectorizer"
+                        self.infos={
+                            "vectorization_method":"ftidf_vectorizer"
+                        }
+                    else:
+                        self.infos={
+                            "vectorization_method":"model_embedding"
+                        }
+                except Exception as ex:
+                    ASCIIColors.error("Couldn't embed the text, so trying to use tfidf instead.")
+                    trace_exception(ex)
                     self.infos={
                         "vectorization_method":"ftidf_vectorizer"
                     }
-                else:
-                    self.infos={
-                        "vectorization_method":"model_embedding"
-                    }
-            except Exception as ex:
-                ASCIIColors.error("Couldn't embed the text, so trying to use tfidf instead.")
-                trace_exception(ex)
-                self.infos={
-                    "vectorization_method":"ftidf_vectorizer"
-                }
         # Load previous state from the JSON file
-        if self.personality_config.save_db:
+        if self.save_db:
             if Path(self.database_file).exists():
                 ASCIIColors.success(f"Database file found : {self.database_file}")
                 self.load_from_json()
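With the constructor decoupled from `APScript`, a `TextVectorizer` can now be built standalone, as the server endpoints above do. A sketch under the assumption of a minimal model-like object, since the class still calls `model.tokenize`/`model.detokenize` even in tfidf mode; the dummy model and the text are illustrative only:

```python
from lollms.utilities import TextVectorizer

class DummyModel:
    """Stand-in exposing the minimal surface TextVectorizer touches."""
    def tokenize(self, text): return text.split()
    def detokenize(self, tokens): return " ".join(tokens)
    def embed(self, text): return None  # forces the tfidf fallback

tv = TextVectorizer("ftidf_vectorizer", model=DummyModel())
tv.add_document("doc_1", "First sentence. Second sentence. Third one.",
                chunk_size=512, overlap_size=0)
tv.index()  # chunking and embedding are now two explicit steps
docs, similarities = tv.recover_text(tv.embed_query("second sentence"))
```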
@@ -56,7 +90,7 @@ class TextVectorizer:
                 ASCIIColors.info(f"No database file found : {self.database_file}")

-    def show_document(self, query_text=None):
+    def show_document(self, query_text=None, save_fig_path=None, show_interactive_form=False):
         import textwrap
         import seaborn as sns
         import matplotlib.pyplot as plt
@@ -66,9 +100,8 @@ class TextVectorizer:
         from sklearn.manifold import TSNE
         from sklearn.decomposition import PCA
-        import torch

-        if self.personality_config.data_visualization_method=="PCA":
+        if self.data_visualization_method=="PCA":
             use_pca = True
         else:
             use_pca = False
@@ -80,6 +113,7 @@ class TextVectorizer:
         texts = list(self.texts.values())
         embeddings = self.embeddings
         emb = list(embeddings.values())
+        ref = list(embeddings.keys())
         if len(emb)>=2:
             # Normalize embeddings
             emb = np.vstack(emb)
@@ -94,6 +128,7 @@ class TextVectorizer:

                 # Combine the query embedding with the document embeddings
                 combined_embeddings = np.vstack((normalized_embeddings, query_normalized_embedding))
+                ref.append("Query_chunk_0")
             else:
                 # Combine the query embedding with the document embeddings
                 combined_embeddings = normalized_embeddings
@@ -113,13 +148,19 @@ class TextVectorizer:
                 tsne = TSNE(n_components=2, perplexity=perplexity)
                 embeddings_2d = tsne.fit_transform(combined_embeddings)

+            # Create a dictionary to map document paths to colors
+            document_path_colors = {}
+            for i, path in enumerate(ref):
+                document_path = "_".join(path.split("_")[:-1]) # Extract the document path (excluding chunk and chunk number)
+                if document_path not in document_path_colors:
+                    # Assign a new color to the document path if it's not in the dictionary
+                    document_path_colors[document_path] = sns.color_palette("hls", len(document_path_colors) + 1)[-1]
+
+            # Generate a list of colors for each data point based on the document path
+            point_colors = [document_path_colors["_".join(path.split("_")[:-1])] for path in ref]

             # Create a scatter plot using Seaborn
-            if query_text is not None:
-                sns.scatterplot(x=embeddings_2d[:-1, 0], y=embeddings_2d[:-1, 1]) # Plot document embeddings
-                plt.scatter(embeddings_2d[-1, 0], embeddings_2d[-1, 1], color='red') # Plot query embedding
-            else:
-                sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1]) # Plot document embeddings
+            sns.scatterplot(x=embeddings_2d[:, 0], y=embeddings_2d[:, 1], hue=point_colors) # Plot document embeddings
             # Add labels to the scatter plot
             for i, (x, y) in enumerate(embeddings_2d[:-1]):
                 plt.text(x, y, str(i), fontsize=8)
@@ -176,11 +217,12 @@ class TextVectorizer:
             # Connect the click event handler to the figure
             plt.gcf().canvas.mpl_connect("button_press_event", on_click)

-            plt.savefig(self.lollms_paths.personal_uploads_path / self.personality.personality_folder_name / "db.png")
-            plt.show()
+            if save_fig_path:
+                plt.savefig(save_fig_path)
+            if show_interactive_form:
+                plt.show()

-    def index_document(self, document_id, text, chunk_size, overlap_size, force_vectorize=False):
-
+    def add_document(self, document_id, text, chunk_size, overlap_size, force_vectorize=False):
         if document_id in self.embeddings and not force_vectorize:
             print(f"Document {document_id} already exists. Skipping vectorization.")
             return
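Continuing the earlier sketch, the new parameters make `show_document` usable headlessly, saving the 2D chunk map to a file instead of opening the interactive matplotlib window (the file name is a placeholder):

```python
# requires an indexed TextVectorizer (tv) as built in the sketch above
tv.show_document(query_text="second sentence",
                 save_fig_path="db_map.png",
                 show_interactive_form=False)
```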
@@ -188,15 +230,13 @@ class TextVectorizer:

         # Split tokens into sentences
         sentences = text.split('. ')
         def remove_empty_sentences(sentences):
-            return [sentence for sentence in sentences if sentence.strip() != '']
+            return [self.model.tokenize(sentence) for sentence in sentences if sentence.strip() != '']
         sentences = remove_empty_sentences(sentences)
         # Generate chunks with overlap and sentence boundaries
         chunks = []
         current_chunk = []
         for i in range(len(sentences)):
-            sentence = sentences[i]
-            sentence_tokens = self.model.tokenize(sentence)
-
+            sentence_tokens = sentences[i]
             # ASCIIColors.yellow(len(sentence_tokens))

@@ -204,45 +244,51 @@ class TextVectorizer:
             if len(current_chunk) + len(sentence_tokens) <= chunk_size:
                 current_chunk.extend(sentence_tokens)
             else:
                 if current_chunk:
                     chunks.append(current_chunk)
-
-                while len(sentence_tokens)>chunk_size:
-                    current_chunk = sentence_tokens[0:chunk_size]
-                    sentence_tokens = sentence_tokens[chunk_size:]
-                    chunks.append(current_chunk)
-                current_chunk = sentence_tokens
-
+
+                current_chunk=[]
+                for j in reversed(range(overlap_size)):
+                    current_chunk.extend(sentences[i-j-1])
+                current_chunk.extend(sentence_tokens)

         if current_chunk:
             chunks.append(current_chunk)

-        if self.personality_config.vectorization_method=="ftidf_vectorizer":
-            from sklearn.feature_extraction.text import TfidfVectorizer
+        for i, chunk_text in enumerate(chunks):
+            chunk_id = f"{document_id}_chunk_{i + 1}"
+            chunk_dict = {
+                "chunk_id": chunk_id,
+                "chunk_text": chunk_text
+            }
+            self.chunks.append(chunk_dict)
+
+    def index(self):
+        if self.vectorization_method=="ftidf_vectorizer":
             self.vectorizer = TfidfVectorizer()
-            #if self.personality.config.debug:
+            #if self.debug:
             #    ASCIIColors.yellow(','.join([len(chunk) for chunk in chunks]))
             data=[]
-            for chunk in chunks:
+            for chunk in self.chunks:
                 try:
-                    data.append(self.model.detokenize(chunk).replace("<s>","").replace("</s>","") )
+                    data.append(self.model.detokenize(chunk["chunk_text"]).replace("<s>","").replace("</s>","") )
                 except Exception as ex:
                     print("oups")
             self.vectorizer.fit(data)

         self.embeddings = {}
         # Generate embeddings for each chunk
-        for i, chunk in enumerate(chunks):
-            # Store chunk ID, embedding, and original text
-            chunk_id = f"{document_id}_chunk_{i + 1}"
+        for i, chunk in enumerate(self.chunks):
+            # Store chunk ID, embedding, and original text
+            chunk_id = chunk["chunk_id"]
+            chunk_text = chunk["chunk_text"]
             try:
-                self.texts[chunk_id] = self.model.detokenize(chunk[:chunk_size])
-                if self.personality_config.vectorization_method=="ftidf_vectorizer":
+                self.texts[chunk_id] = self.model.detokenize(chunk_text)
+                if self.vectorization_method=="ftidf_vectorizer":
                     self.embeddings[chunk_id] = self.vectorizer.transform([self.texts[chunk_id]]).toarray()
                 else:
                     self.embeddings[chunk_id] = self.model.embed(self.texts[chunk_id])
             except Exception as ex:
                 print("oups")

-        if self.personality_config.save_db:
+        if self.save_db:
             self.save_to_json()

         self.ready = True
@@ -252,10 +298,14 @@ class TextVectorizer:

     def embed_query(self, query_text):
         # Generate query embedding
-        if self.personality_config.vectorization_method=="ftidf_vectorizer":
+        if self.vectorization_method=="ftidf_vectorizer":
             query_embedding = self.vectorizer.transform([query_text]).toarray()
         else:
             query_embedding = self.model.embed(query_text)
+            if query_embedding is None:
+                ASCIIColors.warning("The model doesn't implement embedding extraction")
+                self.vectorization_method="ftidf_vectorizer"
+                query_embedding = self.vectorizer.transform([query_text]).toarray()

         return query_embedding

@@ -277,6 +327,18 @@ class TextVectorizer:

         return texts, sorted_similarities

+    def toJson(self):
+        state = {
+            "embeddings": {str(k): v.tolist() if type(v)!=list else v for k, v in self.embeddings.items() },
+            "texts": self.texts,
+            "infos": self.infos,
+            "vectorizer": TFIDFLoader.create_dict_from_vectorizer(self.vectorizer) if self.vectorization_method=="ftidf_vectorizer" else None
+        }
+        return state
+
+    def setVectorizer(self, vectorizer_dict:dict):
+        self.vectorizer = TFIDFLoader.create_vectorizer_from_dict(vectorizer_dict)
+
     def save_to_json(self):
         state = {
             "embeddings": {str(k): v.tolist() if type(v)!=list else v for k, v in self.embeddings.items() },
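The `toJson`/`database_dict` pair enables the round trip the server's `vectorize_text` and `query_database` endpoints rely on: `toJson()` yields a plain dict (with the tfidf vectorizer state flattened by `TFIDFLoader`) that survives `json.dump`. A sketch reusing the `DummyModel` stand-in from above; the file name is a placeholder:

```python
import json

db = tv.toJson()                     # plain, JSON-serializable dict
with open("my_db.json", "w") as f:
    json.dump(db, f, indent=4)

with open("my_db.json", "r") as f:
    restored = TextVectorizer("ftidf_vectorizer", model=DummyModel(),
                              database_dict=json.load(f))
```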
@@ -295,7 +357,7 @@ class TextVectorizer:
         self.texts = state["texts"]
         self.infos= state["infos"]
         self.ready = True
-        if self.personality_config.vectorization_method=="ftidf_vectorizer":
+        if self.vectorization_method=="ftidf_vectorizer":
             from sklearn.feature_extraction.text import TfidfVectorizer
             data = list(self.texts.values())
             if len(data)>0:
@@ -304,11 +366,13 @@ class TextVectorizer:
             self.embeddings={}
             for k,v in self.texts.items():
                 self.embeddings[k]= self.vectorizer.transform([v]).toarray()
+
+
     def clear_database(self):
         self.vectorizer=None
         self.embeddings = {}
         self.texts={}
-        if self.personality_config.save_db:
+        if self.save_db:
             self.save_to_json()

diff --git a/requirements.txt b/requirements.txt
index d5c4a9e..6811d00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,9 @@ simple-websocket
 eventlet
 wget
 setuptools
-requests
\ No newline at end of file
+requests
+
+matplotlib
+seaborn
+mplcursors
+scikit-learn
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 076be94..ab81183 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@ def get_all_files(path):

 setuptools.setup(
     name="lollms",
-    version="2.1.56",
+    version="2.1.59",
     author="Saifeddine ALOUI",
     author_email="aloui.saifeddine@gmail.com",
     description="A python library for AI personality definition",