upgraded securities

This commit is contained in:
Saifeddine ALOUI 2024-03-15 23:00:32 +01:00
parent 4c1e7c9b08
commit 3c2c2069c0
5 changed files with 129 additions and 39 deletions

View File

@ -635,7 +635,7 @@ class LollmsApplication(LoLLMsCom):
discussion = self.recover_discussion(client_id)
query = self.personality.fast_gen(f"!@>discussion:\n{discussion[-2048:]}\n!@>system: Read the discussion and craft a short skills database search query suited to recover needed information to reply to last {self.config.user_name} message.\nDo not answer the prompt. Do not add explanations.\n!@>search query: ", max_generation_size=256, show_progress=True, callback=self.personality.sink)
# skills = self.skills_library.query_entry(query)
skills = self.skills_library.query_entry_fts(query)
skills, sorted_similarities, document_ids = self.skills_library.query_vector_db(query, top_k=3, max_dist=1000)#query_entry_fts(query)
if len(skills)>0:
if knowledge=="":

View File

@ -1,10 +1,12 @@
import sqlite3
from safe_store.text_vectorizer import TextVectorizer, VectorizationMethod, VisualizationMethod
class SkillsLibrary:
def __init__(self, db_path):
self.db_path =db_path
self._initialize_db()
self.vectorizer = TextVectorizer(VectorizationMethod.TFIDF_VECTORIZER)
def _initialize_db(self):
conn = sqlite3.connect(self.db_path)
@ -118,7 +120,74 @@ class SkillsLibrary:
conn.close()
return res
def query_vector_db(self, query, top_k=3, max_dist=1000):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection.
query = "SELECT title FROM skills_library"
cursor.execute(query)
res = cursor.fetchall()
cursor.close()
conn.close()
for entry in res:
self.vectorizer.add_document(entry[0])
self.vectorizer.index()
skill_titles, sorted_similarities, document_ids = self.vectorizer.recover_text(query, top_k)
skills = []
for skill, sim in zip(skill_titles, sorted_similarities):
if sim>max_dist:
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection.
query = "SELECT content FROM skills_library WHERE title LIKE ?"
res = cursor.execute(query, (skill,))
skills.append(res[0])
cursor.execute(query)
res = cursor.fetchall()
cursor.close()
conn.close()
return skill_titles, skills
def dump(self):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection.
query = "SELECT * FROM skills_library"
cursor.execute(query)
res = cursor.fetchall()
cursor.close()
conn.close()
return [[r[0], r[1], r[2], r[3]] for r in res]
def get_categories(self):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection.
query = "SELECT category FROM skills_library"
cursor.execute(query)
res = cursor.fetchall()
cursor.close()
conn.close()
return [r[0] for r in res]
def get_titles(self, category):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Use direct string concatenation for the MATCH expression.
# Ensure text is safely escaped to avoid SQL injection.
query = "SELECT title FROM skills_library WHERE category=?"
cursor.execute(query,(category,))
res = cursor.fetchall()
cursor.close()
conn.close()
return [r[0] for r in res]
def remove_entry(self, id):
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()

View File

@ -11,24 +11,19 @@ import re
def sanitize_path(path:str, allow_absolute_path:bool=False, error_text="Absolute database path detected", exception_text="Detected an attempt of path traversal. Are you kidding me?"):
if path is None:
return path
if(".." in path):
ASCIIColors.warning(error_text)
raise Exception(exception_text)
if (not allow_absolute_path) and Path(path).is_absolute():
ASCIIColors.warning(error_text)
raise Exception(exception_text)
return path
def sanitize_path_from_endpoint(path:str, error_text="A suspected LFI attack detected. The path sent to the server has .. in it!", exception_text="Invalid path!"):
if path is None:
return path
if (".." in path or Path(path).is_absolute()):
# Regular expression to detect patterns like "...." and multiple forward slashes
suspicious_patterns = re.compile(r'(\.\.+)|(/+/)')
if suspicious_patterns.search(path) or ((not allow_absolute_path) and Path(path).is_absolute()):
ASCIIColors.error(error_text)
raise HTTPException(status_code=400, detail=exception_text)
return path
def sanitize_path_from_endpoint(path: str, error_text="A suspected LFI attack detected. The path sent to the server has suspicious elements in it!", exception_text="Invalid path!"):
# Fix the case of "/" at the beginning on the path
path = path.lstrip('/')
if path is None:
return path
@ -41,6 +36,7 @@ def sanitize_path_from_endpoint(path: str, error_text="A suspected LFI attack de
return path
def forbid_remote_access(lollmsElfServer):
if lollmsElfServer.config.host!="localhost" and lollmsElfServer.config.host!="127.0.0.1":
raise Exception("This functionality is forbidden if the server is exposed")

View File

@ -39,9 +39,11 @@ async def serve_user_infos(path: str):
Returns:
FileResponse: The file response containing the requested file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
file_path = (lollmsElfServer.lollms_paths.personal_user_infos_path / path).resolve()
file_path:Path = lollmsElfServer.lollms_paths.personal_user_infos_path / path
if not file_path.exists():
raise HTTPException(status_code=400, detail="File not found")
return FileResponse(str(file_path))
# ----------------------------------- Lollms zoos -----------------------------------------
@ -57,11 +59,11 @@ async def serve_bindings(path: str):
FileResponse: The file response containing the requested bindings file.
"""
sanitize_path_from_endpoint(path)
file_path = (lollmsElfServer.lollms_paths.bindings_zoo_path / path).resolve()
path = sanitize_path_from_endpoint(path)
file_path = lollmsElfServer.lollms_paths.bindings_zoo_path / path
if not Path(file_path).exists():
raise ValueError("File not found")
raise HTTPException(status_code=400, detail="File not found")
return FileResponse(str(file_path))
@router.get("/personalities/{path:path}")
@ -75,15 +77,15 @@ async def serve_personalities(path: str):
Returns:
FileResponse: The file response containing the requested personalities file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
if "custom_personalities" in path:
file_path = (lollmsElfServer.lollms_paths.custom_personalities_path / "/".join(str(path).split("/")[1:])).resolve()
file_path = lollmsElfServer.lollms_paths.custom_personalities_path / "/".join(str(path).split("/")[1:])
else:
file_path = (lollmsElfServer.lollms_paths.personalities_zoo_path / path).resolve()
file_path = lollmsElfServer.lollms_paths.personalities_zoo_path / path
if not Path(file_path).exists():
raise ValueError("File not found")
raise HTTPException(status_code=400, detail="File not found")
return FileResponse(str(file_path))
@ -99,16 +101,12 @@ async def serve_extensions(path: str):
Returns:
FileResponse: The file response containing the requested extensions file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
file_path = (lollmsElfServer.lollms_paths.extensions_zoo_path / path).resolve()
file_path = lollmsElfServer.lollms_paths.extensions_zoo_path / path
if not Path(file_path).exists():
raise ValueError("File not found")
if not Path(file_path).exists():
raise HTTPException(status_code=404, detail="File not found")
raise HTTPException(status_code=400, detail="File not found")
return FileResponse(str(file_path))
@ -125,7 +123,7 @@ async def serve_audio(path: str):
Returns:
FileResponse: The file response containing the requested audio file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = Path(lollmsElfServer.lollms_paths.personal_outputs_path).resolve()
file_path = root_dir/ 'audio_out' / path
@ -147,10 +145,10 @@ async def serve_images(path: str):
Returns:
FileResponse: The file response containing the requested image file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = Path(os.getcwd())/ "images/"
file_path = (root_dir / path).resolve()
file_path = root_dir / path
if not Path(file_path).exists():
raise HTTPException(status_code=404, detail="File not found")
@ -171,7 +169,7 @@ async def serve_outputs(path: str):
Returns:
FileResponse: The file response containing the requested output file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = lollmsElfServer.lollms_paths.personal_outputs_path
root_dir.mkdir(exist_ok=True, parents=True)
@ -195,7 +193,7 @@ async def serve_data(path: str):
Returns:
FileResponse: The file response containing the requested data file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = lollmsElfServer.lollms_paths.personal_path / "data"
root_dir.mkdir(exist_ok=True, parents=True)
@ -220,7 +218,7 @@ async def serve_help(path: str):
Returns:
FileResponse: The file response containing the requested data file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = Path(os.getcwd())
file_path = root_dir/'help/' / path
@ -243,7 +241,7 @@ async def serve_uploads(path: str):
Returns:
FileResponse: The file response containing the requested uploads file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = lollmsElfServer.lollms_paths.personal_path / "uploads"
root_dir.mkdir(exist_ok=True, parents=True)
@ -267,7 +265,7 @@ async def serve_discussions(path: str):
Returns:
FileResponse: The file response containing the requested uploads file.
"""
sanitize_path_from_endpoint(path)
path = sanitize_path_from_endpoint(path)
root_dir = lollmsElfServer.lollms_paths.personal_discussions_path
root_dir.mkdir(exist_ok=True, parents=True)

View File

@ -28,6 +28,33 @@ lollmsElfServer:LOLLMSWebUI = LOLLMSWebUI.get_instance()
class DiscussionInfos(BaseModel):
client_id: str
class CategoryData(BaseModel):
client_id: str
category: str
class CategoryData(BaseModel):
client_id: str
category: str
title: str
@router.post("/get_skills_library")
def get_skills_library_categories(discussionInfos:DiscussionInfos):
return {"status":True, "entries":lollmsElfServer.skills_library.dump()}
@router.post("/get_skills_library_categories")
def get_skills_library_categories(discussionInfos:DiscussionInfos):
return {"status":True, "categories":lollmsElfServer.skills_library.get_categories()}
@router.post("/get_skills_library_titles")
def get_skills_library_categories(categoryData:CategoryData):
return {"status":True, "titles":lollmsElfServer.skills_library.get_titles(categoryData.category)}
@router.post("/get_skills_library_content")
def get_skills_library_categories(categoryData:CategoryData):
return {"status":True, "contents":lollmsElfServer.skills_library.get_titles(categoryData.category)}
@router.post("/add_discussion_to_skills_library")
def add_discussion_to_skills_library(discussionInfos:DiscussionInfos):
lollmsElfServer.ShowBlockingMessage("Learning...")