enhanced lollms client

This commit is contained in:
Saifeddine ALOUI 2024-08-30 00:57:21 +02:00
parent 55989f795b
commit 767644c305
2 changed files with 160 additions and 65 deletions

118
lollms/code_modifier.py Normal file
View File

@ -0,0 +1,118 @@
import os
import subprocess
from typing import Dict, Tuple
import re
class CodeModifier:
    """Applies LLM-generated unified diffs to files inside a git-tracked folder.

    The workflow is: build a prompt from the current file content, send it to
    an LLM callable, parse the returned diff + commit message, apply the diff
    in place, and commit the result with git.
    """

    def __init__(self, folder_path: str):
        # Resolve to an absolute path so file access and git subprocesses
        # (which use cwd=) always agree on the same directory.
        self.folder_path = os.path.abspath(folder_path)
        self.init_git()

    def init_git(self):
        """Initialize a git repository in the folder if one is not already present."""
        if not os.path.exists(os.path.join(self.folder_path, '.git')):
            subprocess.run(['git', 'init'], cwd=self.folder_path, check=True)

    def _abs_path(self, file_path: str) -> str:
        """Return the absolute path of *file_path* inside the managed folder."""
        return os.path.join(self.folder_path, file_path)

    def generate_llm_prompt(self, file_path: str, modification_instruction: str) -> str:
        """Build the prompt asking the LLM for a unified diff and a commit message.

        Reads the current content of *file_path* and embeds it in the prompt.
        """
        with open(self._abs_path(file_path), 'r', encoding='utf-8') as file:
            content = file.read()
        prompt = f"""
Task: Modify the following code according to the instructions.
File: {file_path}
Instructions: {modification_instruction}
Please provide your response in the following format:
1. A unified diff of the changes
2. A commit message for the changes
Current code:
{content}
Make sure to use the unified diff format for your changes, starting with '--- {file_path}' and '+++ {file_path}', followed by '@@ ... @@' for line numbers.
"""
        return prompt

    def parse_llm_response(self, response: str) -> Tuple[str, str]:
        """Split an LLM response into (diff, commit_message).

        The diff is everything from the first '---' up to the
        'Commit message:' marker; the commit message is everything after it.

        Raises:
            ValueError: if the response lacks either the diff or the
                'Commit message:' section.
        """
        diff_pattern = r'(---[\s\S]*?)(?=\nCommit message:|$)'
        commit_pattern = r'Commit message:([\s\S]*)'
        diff_match = re.search(diff_pattern, response)
        commit_match = re.search(commit_pattern, response)
        if not diff_match or not commit_match:
            raise ValueError("Invalid LLM response format")
        diff = diff_match.group(1).strip()
        commit_message = commit_match.group(1).strip()
        return diff, commit_message

    def apply_diff(self, file_path: str, diff: str) -> None:
        """Apply a unified diff to *file_path*, rewriting the file in place.

        Removal lines are verified against the current content (ignoring
        surrounding whitespace) before being dropped.

        Raises:
            ValueError: when a removal line does not match the file content,
                or references a line beyond the end of the file.
        """
        target = self._abs_path(file_path)
        with open(target, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        diff_lines = diff.split('\n')
        current_line = 0
        for line in diff_lines[2:]:  # Skip the '---' and '+++' header lines
            if line.startswith('@@'):
                match = re.match(r'@@ -\d+,?\d* \+(\d+),?\d* @@', line)
                if match:
                    # Hunk header gives the 1-based start line in the new file.
                    current_line = int(match.group(1)) - 1
            elif line.startswith('-'):
                # Guard against hunks pointing past the end of the file so the
                # caller gets a diagnostic ValueError instead of an IndexError.
                if current_line >= len(lines):
                    raise ValueError(f"Mismatch at line {current_line + 1}")
                if lines[current_line].strip() == line[1:].strip():
                    lines.pop(current_line)
                else:
                    raise ValueError(f"Mismatch at line {current_line + 1}")
            elif line.startswith('+'):
                lines.insert(current_line, line[1:] + '\n')
                current_line += 1
            else:
                # Context line: advance without modifying.
                current_line += 1
        with open(target, 'w', encoding='utf-8') as file:
            file.writelines(lines)

    def commit_changes(self, file_path: str, commit_message: str) -> None:
        """Stage *file_path* and commit it with *commit_message*.

        Raises:
            RuntimeError: if either git command exits with a non-zero status.
        """
        try:
            subprocess.run(['git', 'add', file_path], cwd=self.folder_path, check=True)
            subprocess.run(['git', 'commit', '-m', commit_message], cwd=self.folder_path, check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to commit changes: {e}")

    def modify_code(self, file_path: str, modification_instruction: str, llm_function) -> Dict[str, str]:
        """Run the full cycle: prompt the LLM, parse, apply the diff, commit.

        *llm_function* is any callable mapping a prompt string to a response
        string. Returns a {"status", "message"} dict; failures are reported
        in the dict rather than raised.
        """
        prompt = self.generate_llm_prompt(file_path, modification_instruction)
        llm_response = llm_function(prompt)
        try:
            diff, commit_message = self.parse_llm_response(llm_response)
            print("Generated diff:")
            print(diff)
            self.apply_diff(file_path, diff)
            self.commit_changes(file_path, commit_message)
            return {"status": "success", "message": "Code modified and committed successfully"}
        except Exception as e:
            return {"status": "error", "message": str(e)}
# Example usage:
def mock_llm_function(prompt):
    """Stand-in for a real LLM call: ignores *prompt*, returns a canned response.

    The canned text follows the format requested by generate_llm_prompt —
    a unified diff followed by a 'Commit message:' section — so it can be
    fed straight into CodeModifier.modify_code for testing.
    """
    # This is a mock function to simulate LLM response
    # NOTE(review): the diff body below appears to have lost the leading
    # whitespace that unified-diff context/changed lines normally carry —
    # verify the exact spacing against a real LLM response before relying
    # on apply_diff succeeding with it.
    return """
--- example.py
+++ example.py
@@ -1,5 +1,6 @@
def hello_world():
- print("Hello, World!")
+ print("Hello, Universe!")
+ return "Greetings from the cosmos!"
hello_world()
Commit message:
Update hello_world function to greet the universe and return a message
"""
if __name__=="__main__":
    # Usage
    # NOTE(review): hard-coded local Windows path — adjust before running
    # on another machine. Creating the modifier initializes a git repo in
    # that folder as a side effect.
    modifier = CodeModifier(r"C:\Users\aloui\Documents\ai\test_code_modif")
    # Ask the (mock) LLM to edit example.py, apply the returned diff, and commit.
    result = modifier.modify_code("example.py", "Change the greeting to 'Hello, Universe!' and make the function return a string", mock_llm_function)
    print(result)

View File

@ -1,29 +1,6 @@
"""
project: lollms_webui
file: lollms_rag.py
author: ParisNeo
description:
This module contains a set of FastAPI routes that allow users to interact with the RAG (Retrieval-Augmented Generation) library.
Usage:
1. Initialize the RAG system by adding documents using the /add_document endpoint.
2. Build the index using the /index_database endpoint.
3. Perform searches using the /search endpoint.
4. Remove documents using the /remove_document/{document_id} endpoint.
5. Wipe the entire database using the /wipe_database endpoint.
Authentication:
- If lollms_access_keys are specified in the configuration, API key authentication is required.
- If no keys are specified, authentication is bypassed, and all users are treated as user ID 1.
User Management:
- Each user gets a unique vectorizer based on their API key.
- If no API keys are specified, all requests are treated as coming from user ID 1.
Note: Ensure proper security measures are in place when deploying this API in a production environment.
"""
from fastapi import APIRouter, Request, HTTPException, Depends, Header
import uuid
from fastapi import Request, Response
from lollms_webui import LOLLMSWebUI
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
@ -34,7 +11,6 @@ from ascii_colors import ASCIIColors
from lollms.databases.discussions_database import DiscussionsDB, Discussion
from typing import List, Optional, Union
from pathlib import Path
from fastapi.security import APIKeyHeader
from lollmsvectordb.database_elements.chunk import Chunk
from lollmsvectordb.vector_database import VectorDatabase
from lollmsvectordb.lollms_vectorizers.bert_vectorizer import BERTVectorizer
@ -53,12 +29,12 @@ import hashlib
router = APIRouter()
lollmsElfServer: LOLLMSWebUI = LOLLMSWebUI.get_instance()
api_key_header = APIKeyHeader(name="Authorization")
# ----------------------- RAG System ------------------------------
class RAGQuery(BaseModel):
query: str = Field(..., description="The query to process using RAG")
key: str = Field(..., description="The key to identify the user")
class RAGResponse(BaseModel):
answer: str = Field(..., description="The generated answer")
@ -68,6 +44,7 @@ class IndexDocument(BaseModel):
title: str = Field(..., description="The title of the document")
content: str = Field(..., description="The content to be indexed")
path: str = Field(default="unknown", description="The path of the document")
key: str = Field(..., description="The key to identify the user")
class IndexResponse(BaseModel):
success: bool = Field(..., description="Indicates if the indexing was successful")
@ -86,22 +63,9 @@ class RAGChunk(BaseModel):
nb_tokens : int
distance : float
def get_user_id(bearer_key: str) -> int:
"""
Determine the user ID based on the bearer key.
If no keys are specified in the configuration, always return 1.
"""
if not lollmsElfServer.config.lollms_access_keys:
return 1
# Use the index of the key in the list as the user ID
try:
return lollmsElfServer.config.lollms_access_keys.index(bearer_key) + 1
except ValueError:
raise HTTPException(status_code=403, detail="Invalid API Key")
def get_user_vectorizer(user_id: int, bearer_key: str):
small_key = hashlib.md5(bearer_key.encode()).hexdigest()[:8]
user_folder = lollmsElfServer.lollms_paths / str(user_id)
def get_user_vectorizer(user_key: str):
small_key = hashlib.md5(user_key.encode()).hexdigest()[:8]
user_folder = lollmsElfServer.lollms_paths.personal_outputs_path / str(user_key)
user_folder.mkdir(parents=True, exist_ok=True)
return VectorDatabase(
str(user_folder / f"rag_db_{small_key}.sqlite"),
@ -111,45 +75,58 @@ def get_user_vectorizer(user_id: int, bearer_key: str):
overlap=lollmsElfServer.config.rag_overlap
)
async def get_current_user(bearer_token: str = Depends(api_key_header)):
if lollmsElfServer.config.lollms_access_keys:
if bearer_token not in lollmsElfServer.config.lollms_access_keys:
raise HTTPException(status_code=403, detail="Invalid API Key")
return bearer_token
async def validate_key(key: str):
    """Return *key* unchanged if it is acceptable.

    When lollms_access_keys is configured, the key must be one of the
    configured access keys; otherwise every key is accepted (open mode).

    Raises:
        HTTPException: 403 when keys are configured and *key* is not among them.
    """
    if lollmsElfServer.config.lollms_access_keys and key not in lollmsElfServer.config.lollms_access_keys:
        raise HTTPException(status_code=403, detail="Invalid Key")
    return key
@router.post("/add_document", response_model=DocumentResponse)
async def add_document(doc: IndexDocument, user: str = Depends(get_current_user)):
user_id = get_user_id(user)
vectorizer = get_user_vectorizer(user_id, user)
async def add_document(doc: IndexDocument):
await validate_key(doc.key)
vectorizer = get_user_vectorizer(doc.key)
vectorizer.add_document(title=doc.title, text=doc.content, path=doc.path)
return DocumentResponse(success=True, message="Document added successfully.")
@router.post("/remove_document/{document_id}", response_model=DocumentResponse)
async def remove_document(document_id: int, user: str = Depends(get_current_user)):
user_id = get_user_id(user)
vectorizer = get_user_vectorizer(user_id, user)
async def remove_document(document_id: int, key: str):
await validate_key(key)
vectorizer = get_user_vectorizer(key)
doc_hash = vectorizer.get_document_hash(document_id)
vectorizer.remove_document(doc_hash)
# Logic to remove the document by ID
return DocumentResponse(success=True, message="Document removed successfully.")
class IndexDatabaseRequest(BaseModel):
    """Request body for /index_database."""
    # Key identifying the user whose vector database should be indexed;
    # validated against lollms_access_keys by validate_key.
    key: str
@router.post("/index_database", response_model=DocumentResponse)
async def index_database(user: str = Depends(get_current_user)):
user_id = get_user_id(user)
vectorizer = get_user_vectorizer(user_id, user)
async def index_database(request: IndexDatabaseRequest):
key = request.key
await validate_key(key)
vectorizer = get_user_vectorizer(key)
vectorizer.build_index()
return DocumentResponse(success=True, message="Database indexed successfully.")
@router.post("/search", response_model=List[RAGChunk])
async def search(query: RAGQuery, user: str = Depends(get_current_user)):
user_id = get_user_id(user)
vectorizer = get_user_vectorizer(user_id, user)
async def search(query: RAGQuery):
await validate_key(query.key)
vectorizer = get_user_vectorizer(query.key)
chunks = vectorizer.search(query.query)
return [RAGChunk(c.id,c.chunk_id, c.doc.title, c.doc.path, c.text, c.nb_tokens, c.distance) for c in chunks]
return [
RAGChunk(
id=c.id,
chunk_id=c.chunk_id,
title=c.doc.title,
path=c.doc.path,
text=c.text,
nb_tokens=c.nb_tokens,
distance=c.distance
)
for c in chunks
]
@router.delete("/wipe_database", response_model=DocumentResponse)
async def wipe_database(user: str = Depends(get_current_user)):
user_id = get_user_id(user)
user_folder = lollmsElfServer.lollms_paths / str(user_id)
async def wipe_database(key: str):
await validate_key(key)
user_folder = lollmsElfServer.lollms_paths / str(key)
shutil.rmtree(user_folder, ignore_errors=True)
return DocumentResponse(success=True, message="Database wiped successfully.")