moved skills database to the new system

This commit is contained in:
Saifeddine ALOUI 2024-07-10 01:34:23 +02:00
parent 0e176402e3
commit 488229b912
2 changed files with 25 additions and 20 deletions

View File

@ -263,7 +263,7 @@ class LollmsApplication(LoLLMsCom):
def _generate_text(self, prompt): def _generate_text(self, prompt):
max_tokens = self.config.ctx_size - self.model.get_nb_tokens(prompt) max_tokens = min(self.config.ctx_size - self.model.get_nb_tokens(prompt),self.config.max_n_predict)
generated_text = self.model.generate(prompt, max_tokens) generated_text = self.model.generate(prompt, max_tokens)
return generated_text.strip() return generated_text.strip()

View File

@ -1,11 +1,15 @@
import sqlite3 import sqlite3
from safe_store.text_vectorizer import TextVectorizer, VectorizationMethod, VisualizationMethod from lollmsvectordb import VectorDatabase, BERTVectorizer
from lollmsvectordb.lollms_tokenizers.tiktoken_tokenizer import TikTokenTokenizer
import numpy as np import numpy as np
from ascii_colors import ASCIIColors
class SkillsLibrary: class SkillsLibrary:
def __init__(self, db_path): def __init__(self, db_path, model_name: str = 'bert-base-nli-mean-tokens', chunk_size:int=512, overlap:int=0, n_neighbors:int=5):
self.db_path =db_path self.db_path =db_path
self._initialize_db() self._initialize_db()
self.vectorizer = VectorDatabase("", BERTVectorizer(), TikTokenTokenizer(),chunk_size, overlap, n_neighbors)
ASCIIColors.green("Vecorizer ready")
def _initialize_db(self): def _initialize_db(self):
@ -121,37 +125,38 @@ class SkillsLibrary:
return res return res
def query_vector_db(self, query_, top_k=3, max_dist=1000): def query_vector_db(self, query_, top_k=3, max_dist=1000):
vectorizer = TextVectorizer(VectorizationMethod.TFIDF_VECTORIZER)
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
cursor = conn.cursor() cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression. # Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection. # Ensure text is safely escaped to avoid SQL injection.
query = "SELECT id, title FROM skills_library" query = "SELECT id, title, content FROM skills_library"
cursor.execute(query) cursor.execute(query)
res = cursor.fetchall() res = cursor.fetchall()
cursor.close() cursor.close()
conn.close() conn.close()
skills = [] skills = []
skill_titles = []
if len(res)>0: if len(res)>0:
for entry in res: for entry in res:
vectorizer.add_document(entry[0],entry[1]) self.vectorizer.add_document(entry[0],"Title:"+entry[1]+"\n"+entry[2])
vectorizer.index() self.vectorizer.build_index()
skill_titles, sorted_similarities, document_ids = vectorizer.recover_text(query_, top_k) chunks = self.vectorizer.search(query_, top_k)
for skill_title, sim, id in zip(skill_titles, sorted_similarities, document_ids): for chunk in chunks:
if np.linalg.norm(sim[1])<max_dist: if chunk.distance<max_dist:
conn = sqlite3.connect(self.db_path) skills.append(chunk.text)
cursor = conn.cursor() skill_titles.append(chunk.doc.title)
# conn = sqlite3.connect(self.db_path)
# cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression. # Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection. # Ensure text is safely escaped to avoid SQL injection.
query = "SELECT content FROM skills_library WHERE id = ?" #query = "SELECT content FROM skills_library WHERE id = ?"
cursor.execute(query, (id,)) #cursor.execute(query, (chunk.chunk_id,))
res = cursor.fetchall() #res = cursor.fetchall()
skills.append(res[0]) #skills.append(res[0])
cursor.close() #cursor.close()
conn.close() #conn.close()
else:
skill_titles = []
return skill_titles, skills return skill_titles, skills