From 767644c30551c02f8783fad51ebba62ce86f67a1 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Fri, 30 Aug 2024 00:57:21 +0200 Subject: [PATCH] enhanced lollms client --- lollms/code_modifier.py | 118 ++++++++++++++++++++++++++ lollms/server/endpoints/lollms_rag.py | 107 +++++++++-------------- 2 files changed, 160 insertions(+), 65 deletions(-) create mode 100644 lollms/code_modifier.py diff --git a/lollms/code_modifier.py b/lollms/code_modifier.py new file mode 100644 index 0000000..27bd011 --- /dev/null +++ b/lollms/code_modifier.py @@ -0,0 +1,118 @@ +import os +import subprocess +from typing import Dict, Tuple +import re + +class CodeModifier: + def __init__(self, folder_path: str): + self.folder_path = os.path.abspath(folder_path) + self.init_git() + + def init_git(self): + if not os.path.exists(os.path.join(self.folder_path, '.git')): + subprocess.run(['git', 'init'], cwd=self.folder_path, check=True) + + def generate_llm_prompt(self, file_path: str, modification_instruction: str) -> str: + with open(os.path.join(self.folder_path, file_path), 'r') as file: + content = file.read() + + prompt = f""" +Task: Modify the following code according to the instructions. +File: {file_path} +Instructions: {modification_instruction} + +Please provide your response in the following format: +1. A unified diff of the changes +2. A commit message for the changes + +Current code: +{content} + +Make sure to use the unified diff format for your changes, starting with '--- {file_path}' and '+++ {file_path}', followed by '@@ ... @@' for line numbers. +""" + return prompt + + def parse_llm_response(self, response: str) -> Tuple[str, str]: + diff_pattern = r'(---[\s\S]*?)(?=\nCommit message:|$)' + commit_pattern = r'Commit message:([\s\S]*)' + + diff_match = re.search(diff_pattern, response) + commit_match = re.search(commit_pattern, response) + + if not diff_match or not commit_match: + raise ValueError("Invalid LLM response format") + + diff = diff_match.group(1).strip() + commit_message = commit_match.group(1).strip() + + return diff, commit_message + + def apply_diff(self, file_path: str, diff: str) -> None: + with open(os.path.join(self.folder_path, file_path), 'r') as file: + lines = file.readlines() + + diff_lines = diff.split('\n') + current_line = 0 + for line in diff_lines[2:]: # Skip the first two lines (--- and +++) + if line.startswith('@@'): + match = re.match(r'@@ -\d+,?\d* \+(\d+),?\d* @@', line) + if match: + current_line = int(match.group(1)) - 1 + elif line.startswith('-'): + if lines[current_line].strip() == line[1:].strip(): + lines.pop(current_line) + else: + raise ValueError(f"Mismatch at line {current_line + 1}") + elif line.startswith('+'): + lines.insert(current_line, line[1:] + '\n') + current_line += 1 + else: + current_line += 1 + + with open(os.path.join(self.folder_path, file_path), 'w') as file: + file.writelines(lines) + + def commit_changes(self, file_path: str, commit_message: str) -> None: + try: + subprocess.run(['git', 'add', file_path], cwd=self.folder_path, check=True) + subprocess.run(['git', 'commit', '-m', commit_message], cwd=self.folder_path, check=True) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to commit changes: {e}") + + def modify_code(self, file_path: str, modification_instruction: str, llm_function) -> Dict[str, str]: + prompt = self.generate_llm_prompt(file_path, modification_instruction) + llm_response = llm_function(prompt) + + try: + diff, commit_message = self.parse_llm_response(llm_response) + print("Generated diff:") + print(diff) + self.apply_diff(file_path, diff) + self.commit_changes(file_path, commit_message) + return {"status": "success", "message": "Code modified and committed successfully"} + except Exception as e: + return {"status": "error", "message": str(e)} + +# Example usage: +def mock_llm_function(prompt): + # This is a mock function to simulate LLM response + return """ +--- example.py ++++ example.py +@@ -1,5 +1,6 @@ + def hello_world(): +- print("Hello, World!") ++ print("Hello, Universe!") ++ return "Greetings from the cosmos!" + + hello_world() + +Commit message: +Update hello_world function to greet the universe and return a message +""" + +if __name__=="__main__": + # Usage + modifier = CodeModifier(r"C:\Users\aloui\Documents\ai\test_code_modif") + result = modifier.modify_code("example.py", "Change the greeting to 'Hello, Universe!' and make the function return a string", mock_llm_function) + print(result) diff --git a/lollms/server/endpoints/lollms_rag.py b/lollms/server/endpoints/lollms_rag.py index 044cd89..64fd01d 100644 --- a/lollms/server/endpoints/lollms_rag.py +++ b/lollms/server/endpoints/lollms_rag.py @@ -1,29 +1,6 @@ -""" -project: lollms_webui -file: lollms_rag.py -author: ParisNeo -description: - This module contains a set of FastAPI routes that allow users to interact with the RAG (Retrieval-Augmented Generation) library. - -Usage: - 1. Initialize the RAG system by adding documents using the /add_document endpoint. - 2. Build the index using the /index_database endpoint. - 3. Perform searches using the /search endpoint. - 4. Remove documents using the /remove_document/{document_id} endpoint. - 5. Wipe the entire database using the /wipe_database endpoint. - -Authentication: - - If lollms_access_keys are specified in the configuration, API key authentication is required. - - If no keys are specified, authentication is bypassed, and all users are treated as user ID 1. - -User Management: - - Each user gets a unique vectorizer based on their API key. - - If no API keys are specified, all requests are treated as coming from user ID 1. - -Note: Ensure proper security measures are in place when deploying this API in a production environment. -""" - from fastapi import APIRouter, Request, HTTPException, Depends, Header +import uuid +from fastapi import Request, Response from lollms_webui import LOLLMSWebUI from pydantic import BaseModel, Field from starlette.responses import StreamingResponse @@ -34,7 +11,6 @@ from ascii_colors import ASCIIColors from lollms.databases.discussions_database import DiscussionsDB, Discussion from typing import List, Optional, Union from pathlib import Path -from fastapi.security import APIKeyHeader from lollmsvectordb.database_elements.chunk import Chunk from lollmsvectordb.vector_database import VectorDatabase from lollmsvectordb.lollms_vectorizers.bert_vectorizer import BERTVectorizer @@ -53,12 +29,12 @@ import hashlib router = APIRouter() lollmsElfServer: LOLLMSWebUI = LOLLMSWebUI.get_instance() -api_key_header = APIKeyHeader(name="Authorization") # ----------------------- RAG System ------------------------------ class RAGQuery(BaseModel): query: str = Field(..., description="The query to process using RAG") + key: str = Field(..., description="The key to identify the user") class RAGResponse(BaseModel): answer: str = Field(..., description="The generated answer") @@ -68,6 +44,7 @@ class IndexDocument(BaseModel): title: str = Field(..., description="The title of the document") content: str = Field(..., description="The content to be indexed") path: str = Field(default="unknown", description="The path of the document") + key: str = Field(..., description="The key to identify the user") class IndexResponse(BaseModel): success: bool = Field(..., description="Indicates if the indexing was successful") @@ -86,22 +63,9 @@ class RAGChunk(BaseModel): nb_tokens : int distance : float -def get_user_id(bearer_key: str) -> int: - """ - Determine the user ID based on the bearer key. - If no keys are specified in the configuration, always return 1. - """ - if not lollmsElfServer.config.lollms_access_keys: - return 1 - # Use the index of the key in the list as the user ID - try: - return lollmsElfServer.config.lollms_access_keys.index(bearer_key) + 1 - except ValueError: - raise HTTPException(status_code=403, detail="Invalid API Key") - -def get_user_vectorizer(user_id: int, bearer_key: str): - small_key = hashlib.md5(bearer_key.encode()).hexdigest()[:8] - user_folder = lollmsElfServer.lollms_paths / str(user_id) +def get_user_vectorizer(user_key: str): + small_key = hashlib.md5(user_key.encode()).hexdigest()[:8] + user_folder = lollmsElfServer.lollms_paths.personal_outputs_path / str(user_key) user_folder.mkdir(parents=True, exist_ok=True) return VectorDatabase( str(user_folder / f"rag_db_{small_key}.sqlite"), @@ -111,45 +75,58 @@ def get_user_vectorizer(user_id: int, bearer_key: str): overlap=lollmsElfServer.config.rag_overlap ) -async def get_current_user(bearer_token: str = Depends(api_key_header)): - if lollmsElfServer.config.lollms_access_keys: - if bearer_token not in lollmsElfServer.config.lollms_access_keys: - raise HTTPException(status_code=403, detail="Invalid API Key") - return bearer_token +async def validate_key(key: str): + if lollmsElfServer.config.lollms_access_keys and key not in lollmsElfServer.config.lollms_access_keys: + raise HTTPException(status_code=403, detail="Invalid Key") + return key @router.post("/add_document", response_model=DocumentResponse) -async def add_document(doc: IndexDocument, user: str = Depends(get_current_user)): - user_id = get_user_id(user) - vectorizer = get_user_vectorizer(user_id, user) +async def add_document(doc: IndexDocument): + await validate_key(doc.key) + vectorizer = get_user_vectorizer(doc.key) vectorizer.add_document(title=doc.title, text=doc.content, path=doc.path) return DocumentResponse(success=True, message="Document added successfully.") @router.post("/remove_document/{document_id}", response_model=DocumentResponse) -async def remove_document(document_id: int, user: str = Depends(get_current_user)): - user_id = get_user_id(user) - vectorizer = get_user_vectorizer(user_id, user) +async def remove_document(document_id: int, key: str): + await validate_key(key) + vectorizer = get_user_vectorizer(key) doc_hash = vectorizer.get_document_hash(document_id) vectorizer.remove_document(doc_hash) - # Logic to remove the document by ID return DocumentResponse(success=True, message="Document removed successfully.") +class IndexDatabaseRequest(BaseModel): + key: str + @router.post("/index_database", response_model=DocumentResponse) -async def index_database(user: str = Depends(get_current_user)): - user_id = get_user_id(user) - vectorizer = get_user_vectorizer(user_id, user) +async def index_database(request: IndexDatabaseRequest): + key = request.key + await validate_key(key) + vectorizer = get_user_vectorizer(key) vectorizer.build_index() return DocumentResponse(success=True, message="Database indexed successfully.") @router.post("/search", response_model=List[RAGChunk]) -async def search(query: RAGQuery, user: str = Depends(get_current_user)): - user_id = get_user_id(user) - vectorizer = get_user_vectorizer(user_id, user) +async def search(query: RAGQuery): + await validate_key(query.key) + vectorizer = get_user_vectorizer(query.key) chunks = vectorizer.search(query.query) - return [RAGChunk(c.id,c.chunk_id, c.doc.title, c.doc.path, c.text, c.nb_tokens, c.distance) for c in chunks] + return [ + RAGChunk( + id=c.id, + chunk_id=c.chunk_id, + title=c.doc.title, + path=c.doc.path, + text=c.text, + nb_tokens=c.nb_tokens, + distance=c.distance + ) + for c in chunks +] @router.delete("/wipe_database", response_model=DocumentResponse) -async def wipe_database(user: str = Depends(get_current_user)): - user_id = get_user_id(user) - user_folder = lollmsElfServer.lollms_paths / str(user_id) +async def wipe_database(key: str): + await validate_key(key) + user_folder = lollmsElfServer.lollms_paths / str(key) shutil.rmtree(user_folder, ignore_errors=True) return DocumentResponse(success=True, message="Database wiped successfully.")