mirror of
https://github.com/ParisNeo/lollms.git
synced 2024-12-18 20:27:58 +00:00
enhanced lollms client
This commit is contained in:
parent
55989f795b
commit
767644c305
118
lollms/code_modifier.py
Normal file
118
lollms/code_modifier.py
Normal file
@ -0,0 +1,118 @@
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Dict, Tuple
|
||||
import re
|
||||
|
||||
class CodeModifier:
    """Ask an LLM for a unified diff of a file, apply it, and commit the result.

    All file paths handed to the methods are interpreted relative to
    ``folder_path``, which is turned into a git repository on construction
    if it does not already contain a ``.git`` entry.
    """

    def __init__(self, folder_path: str):
        """Store the absolute working folder and make sure it is a git repo.

        Args:
            folder_path: Folder containing the files to modify.

        Raises:
            subprocess.CalledProcessError: If ``git init`` fails.
        """
        self.folder_path = os.path.abspath(folder_path)
        self.init_git()

    def init_git(self):
        """Run ``git init`` in the working folder unless ``.git`` already exists."""
        if not os.path.exists(os.path.join(self.folder_path, '.git')):
            subprocess.run(['git', 'init'], cwd=self.folder_path, check=True)

    def generate_llm_prompt(self, file_path: str, modification_instruction: str) -> str:
        """Build the prompt asking the LLM for a diff and a commit message.

        Args:
            file_path: File to modify, relative to the working folder.
            modification_instruction: Natural-language description of the change.

        Returns:
            The prompt text, embedding the current content of the file.

        Raises:
            OSError: If the file cannot be read.
        """
        # Explicit UTF-8 so the result does not depend on the platform locale.
        with open(os.path.join(self.folder_path, file_path), 'r', encoding='utf-8') as file:
            content = file.read()

        # The 'Commit message:' label is spelled out because
        # parse_llm_response() rejects any answer that does not contain it.
        prompt = f"""
Task: Modify the following code according to the instructions.
File: {file_path}
Instructions: {modification_instruction}

Please provide your response in the following format:
1. A unified diff of the changes
2. A commit message for the changes, introduced by a line reading 'Commit message:'

Current code:
{content}

Make sure to use the unified diff format for your changes, starting with '--- {file_path}' and '+++ {file_path}', followed by '@@ ... @@' for line numbers.
"""
        return prompt

    def parse_llm_response(self, response: str) -> Tuple[str, str]:
        """Split an LLM answer into its diff and its commit message.

        The diff is everything from the first ``---`` up to the line that
        starts with ``Commit message:`` (or the end of the text); the commit
        message is everything after that label.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            A ``(diff, commit_message)`` tuple, both stripped of surrounding
            whitespace.

        Raises:
            ValueError: If the diff or the ``Commit message:`` label is missing.
        """
        diff_pattern = r'(---[\s\S]*?)(?=\nCommit message:|$)'
        commit_pattern = r'Commit message:([\s\S]*)'

        diff_match = re.search(diff_pattern, response)
        commit_match = re.search(commit_pattern, response)

        if not diff_match or not commit_match:
            raise ValueError("Invalid LLM response format")

        diff = diff_match.group(1).strip()
        commit_message = commit_match.group(1).strip()

        return diff, commit_message

    def apply_diff(self, file_path: str, diff: str) -> None:
        """Apply a unified diff to a file in the working folder.

        The first two diff lines are assumed to be the ``---``/``+++`` headers
        and are skipped. Each ``@@`` header repositions the cursor using the
        new-file start line, which is the right index because the line list is
        edited in place. Removed lines are compared with the file ignoring
        leading and trailing whitespace; context lines only advance the cursor
        and are not verified. The file is rewritten only after the whole diff
        has been processed, so a failing diff leaves it untouched.

        Args:
            file_path: File to patch, relative to the working folder.
            diff: Unified diff text, starting with its two header lines.

        Raises:
            ValueError: If a removed line does not match the file content or
                lies beyond the end of the file.
            OSError: If the file cannot be read or written.
        """
        target = os.path.join(self.folder_path, file_path)
        with open(target, 'r', encoding='utf-8') as file:
            lines = file.readlines()

        current_line = 0
        for raw_line in diff.split('\n')[2:]:  # Skip the first two lines (--- and +++)
            # Drop the CR left over when the response used CRLF line endings.
            line = raw_line.rstrip('\r')
            if line.startswith('@@'):
                match = re.match(r'@@ -\d+,?\d* \+(\d+),?\d* @@', line)
                if match:
                    # A start of 0 (e.g. '+0,0' when every line is removed)
                    # must not turn into index -1, i.e. the last line.
                    current_line = max(int(match.group(1)) - 1, 0)
            elif line.startswith('\\'):
                # '\ No newline at end of file' is diff metadata, not content.
                continue
            elif line.startswith('-'):
                if current_line >= len(lines):
                    raise ValueError(f"Mismatch at line {current_line + 1}")
                if lines[current_line].strip() == line[1:].strip():
                    lines.pop(current_line)
                else:
                    raise ValueError(f"Mismatch at line {current_line + 1}")
            elif line.startswith('+'):
                # list.insert clamps to the end; find the line that will
                # precede the new one and terminate it if the file had no
                # final newline, so the two lines are not glued together.
                previous = min(current_line, len(lines)) - 1
                if previous >= 0 and not lines[previous].endswith('\n'):
                    lines[previous] += '\n'
                lines.insert(current_line, line[1:] + '\n')
                current_line += 1
            else:
                current_line += 1

        with open(target, 'w', encoding='utf-8') as file:
            file.writelines(lines)

    def commit_changes(self, file_path: str, commit_message: str) -> None:
        """Stage one file and commit it in the working folder.

        Args:
            file_path: File to stage, relative to the working folder.
            commit_message: Message passed to ``git commit -m``.

        Raises:
            RuntimeError: If ``git add`` or ``git commit`` exits with an error;
                the original ``CalledProcessError`` is kept as the cause.
        """
        try:
            subprocess.run(['git', 'add', file_path], cwd=self.folder_path, check=True)
            subprocess.run(['git', 'commit', '-m', commit_message], cwd=self.folder_path, check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to commit changes: {e}") from e

    def modify_code(self, file_path: str, modification_instruction: str, llm_function) -> Dict[str, str]:
        """Run the full cycle: prompt the LLM, apply its diff, commit.

        Args:
            file_path: File to modify, relative to the working folder.
            modification_instruction: Natural-language description of the change.
            llm_function: Callable that receives the prompt text and returns
                the LLM answer as text.

        Returns:
            ``{"status": "success", ...}`` when the diff was applied and
            committed, otherwise ``{"status": "error", "message": <reason>}``.
            Errors raised while building the prompt or calling
            ``llm_function`` are not caught and propagate to the caller.
        """
        prompt = self.generate_llm_prompt(file_path, modification_instruction)
        llm_response = llm_function(prompt)

        try:
            diff, commit_message = self.parse_llm_response(llm_response)
            print("Generated diff:")
            print(diff)
            self.apply_diff(file_path, diff)
            self.commit_changes(file_path, commit_message)
            return {"status": "success", "message": "Code modified and committed successfully"}
        except Exception as e:
            return {"status": "error", "message": str(e)}
|
||||
|
||||
# Example usage:
def mock_llm_function(prompt):
    """Return a canned LLM answer for the demo; ``prompt`` is ignored.

    The answer holds a unified diff for ``example.py`` followed by the
    ``Commit message:`` label and a one-line commit message.
    """
    # This is a mock function to simulate LLM response
    return """
--- example.py
+++ example.py
@@ -1,5 +1,6 @@
 def hello_world():
-    print("Hello, World!")
+    print("Hello, Universe!")
+    return "Greetings from the cosmos!"
 
 hello_world()

Commit message:
Update hello_world function to greet the universe and return a message
"""


def _run_example():
    """Run the demo in a folder given on the command line, else a new temp folder.

    The original demo pointed at one developer's local Windows path, so it
    could not run anywhere else. ``example.py`` is created when missing so
    that the mock diff has matching content to patch. A temp folder is left
    in place afterwards so the result and the git history can be inspected.
    """
    import sys
    import tempfile

    if len(sys.argv) > 1:
        folder = sys.argv[1]
    else:
        folder = tempfile.mkdtemp(prefix="test_code_modif_")

    example_file = os.path.join(folder, "example.py")
    if not os.path.exists(example_file):
        with open(example_file, "w", encoding="utf-8") as file:
            file.write('def hello_world():\n    print("Hello, World!")\n\nhello_world()\n')

    modifier = CodeModifier(folder)
    result = modifier.modify_code("example.py", "Change the greeting to 'Hello, Universe!' and make the function return a string", mock_llm_function)
    print(result)
    print(f"Working folder: {folder}")


if __name__=="__main__":
    # Usage
    _run_example()
|
@ -1,29 +1,6 @@
|
||||
"""
|
||||
project: lollms_webui
|
||||
file: lollms_rag.py
|
||||
author: ParisNeo
|
||||
description:
|
||||
This module contains a set of FastAPI routes that allow users to interact with the RAG (Retrieval-Augmented Generation) library.
|
||||
|
||||
Usage:
|
||||
1. Initialize the RAG system by adding documents using the /add_document endpoint.
|
||||
2. Build the index using the /index_database endpoint.
|
||||
3. Perform searches using the /search endpoint.
|
||||
4. Remove documents using the /remove_document/{document_id} endpoint.
|
||||
5. Wipe the entire database using the /wipe_database endpoint.
|
||||
|
||||
Authentication:
|
||||
- If lollms_access_keys are specified in the configuration, API key authentication is required.
|
||||
- If no keys are specified, authentication is bypassed, and all users are treated as user ID 1.
|
||||
|
||||
User Management:
|
||||
- Each user gets a unique vectorizer based on their API key.
|
||||
- If no API keys are specified, all requests are treated as coming from user ID 1.
|
||||
|
||||
Note: Ensure proper security measures are in place when deploying this API in a production environment.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException, Depends, Header
|
||||
import uuid
|
||||
from fastapi import Request, Response
|
||||
from lollms_webui import LOLLMSWebUI
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
@ -34,7 +11,6 @@ from ascii_colors import ASCIIColors
|
||||
from lollms.databases.discussions_database import DiscussionsDB, Discussion
|
||||
from typing import List, Optional, Union
|
||||
from pathlib import Path
|
||||
from fastapi.security import APIKeyHeader
|
||||
from lollmsvectordb.database_elements.chunk import Chunk
|
||||
from lollmsvectordb.vector_database import VectorDatabase
|
||||
from lollmsvectordb.lollms_vectorizers.bert_vectorizer import BERTVectorizer
|
||||
@ -53,12 +29,12 @@ import hashlib
|
||||
|
||||
router = APIRouter()
|
||||
lollmsElfServer: LOLLMSWebUI = LOLLMSWebUI.get_instance()
|
||||
api_key_header = APIKeyHeader(name="Authorization")
|
||||
|
||||
# ----------------------- RAG System ------------------------------
|
||||
|
||||
class RAGQuery(BaseModel):
|
||||
query: str = Field(..., description="The query to process using RAG")
|
||||
key: str = Field(..., description="The key to identify the user")
|
||||
|
||||
class RAGResponse(BaseModel):
|
||||
answer: str = Field(..., description="The generated answer")
|
||||
@ -68,6 +44,7 @@ class IndexDocument(BaseModel):
|
||||
title: str = Field(..., description="The title of the document")
|
||||
content: str = Field(..., description="The content to be indexed")
|
||||
path: str = Field(default="unknown", description="The path of the document")
|
||||
key: str = Field(..., description="The key to identify the user")
|
||||
|
||||
class IndexResponse(BaseModel):
|
||||
success: bool = Field(..., description="Indicates if the indexing was successful")
|
||||
@ -86,22 +63,9 @@ class RAGChunk(BaseModel):
|
||||
nb_tokens : int
|
||||
distance : float
|
||||
|
||||
def get_user_id(bearer_key: str) -> int:
|
||||
"""
|
||||
Determine the user ID based on the bearer key.
|
||||
If no keys are specified in the configuration, always return 1.
|
||||
"""
|
||||
if not lollmsElfServer.config.lollms_access_keys:
|
||||
return 1
|
||||
# Use the index of the key in the list as the user ID
|
||||
try:
|
||||
return lollmsElfServer.config.lollms_access_keys.index(bearer_key) + 1
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Invalid API Key")
|
||||
|
||||
def get_user_vectorizer(user_id: int, bearer_key: str):
|
||||
small_key = hashlib.md5(bearer_key.encode()).hexdigest()[:8]
|
||||
user_folder = lollmsElfServer.lollms_paths / str(user_id)
|
||||
def get_user_vectorizer(user_key: str):
|
||||
small_key = hashlib.md5(user_key.encode()).hexdigest()[:8]
|
||||
user_folder = lollmsElfServer.lollms_paths.personal_outputs_path / str(user_key)
|
||||
user_folder.mkdir(parents=True, exist_ok=True)
|
||||
return VectorDatabase(
|
||||
str(user_folder / f"rag_db_{small_key}.sqlite"),
|
||||
@ -111,45 +75,58 @@ def get_user_vectorizer(user_id: int, bearer_key: str):
|
||||
overlap=lollmsElfServer.config.rag_overlap
|
||||
)
|
||||
|
||||
async def get_current_user(bearer_token: str = Depends(api_key_header)):
|
||||
if lollmsElfServer.config.lollms_access_keys:
|
||||
if bearer_token not in lollmsElfServer.config.lollms_access_keys:
|
||||
raise HTTPException(status_code=403, detail="Invalid API Key")
|
||||
return bearer_token
|
||||
async def validate_key(key: str):
|
||||
if lollmsElfServer.config.lollms_access_keys and key not in lollmsElfServer.config.lollms_access_keys:
|
||||
raise HTTPException(status_code=403, detail="Invalid Key")
|
||||
return key
|
||||
|
||||
@router.post("/add_document", response_model=DocumentResponse)
|
||||
async def add_document(doc: IndexDocument, user: str = Depends(get_current_user)):
|
||||
user_id = get_user_id(user)
|
||||
vectorizer = get_user_vectorizer(user_id, user)
|
||||
async def add_document(doc: IndexDocument):
|
||||
await validate_key(doc.key)
|
||||
vectorizer = get_user_vectorizer(doc.key)
|
||||
vectorizer.add_document(title=doc.title, text=doc.content, path=doc.path)
|
||||
return DocumentResponse(success=True, message="Document added successfully.")
|
||||
|
||||
@router.post("/remove_document/{document_id}", response_model=DocumentResponse)
|
||||
async def remove_document(document_id: int, user: str = Depends(get_current_user)):
|
||||
user_id = get_user_id(user)
|
||||
vectorizer = get_user_vectorizer(user_id, user)
|
||||
async def remove_document(document_id: int, key: str):
|
||||
await validate_key(key)
|
||||
vectorizer = get_user_vectorizer(key)
|
||||
doc_hash = vectorizer.get_document_hash(document_id)
|
||||
vectorizer.remove_document(doc_hash)
|
||||
# Logic to remove the document by ID
|
||||
return DocumentResponse(success=True, message="Document removed successfully.")
|
||||
|
||||
class IndexDatabaseRequest(BaseModel):
|
||||
key: str
|
||||
|
||||
@router.post("/index_database", response_model=DocumentResponse)
|
||||
async def index_database(user: str = Depends(get_current_user)):
|
||||
user_id = get_user_id(user)
|
||||
vectorizer = get_user_vectorizer(user_id, user)
|
||||
async def index_database(request: IndexDatabaseRequest):
|
||||
key = request.key
|
||||
await validate_key(key)
|
||||
vectorizer = get_user_vectorizer(key)
|
||||
vectorizer.build_index()
|
||||
return DocumentResponse(success=True, message="Database indexed successfully.")
|
||||
|
||||
@router.post("/search", response_model=List[RAGChunk])
|
||||
async def search(query: RAGQuery, user: str = Depends(get_current_user)):
|
||||
user_id = get_user_id(user)
|
||||
vectorizer = get_user_vectorizer(user_id, user)
|
||||
async def search(query: RAGQuery):
|
||||
await validate_key(query.key)
|
||||
vectorizer = get_user_vectorizer(query.key)
|
||||
chunks = vectorizer.search(query.query)
|
||||
return [RAGChunk(c.id,c.chunk_id, c.doc.title, c.doc.path, c.text, c.nb_tokens, c.distance) for c in chunks]
|
||||
return [
|
||||
RAGChunk(
|
||||
id=c.id,
|
||||
chunk_id=c.chunk_id,
|
||||
title=c.doc.title,
|
||||
path=c.doc.path,
|
||||
text=c.text,
|
||||
nb_tokens=c.nb_tokens,
|
||||
distance=c.distance
|
||||
)
|
||||
for c in chunks
|
||||
]
|
||||
|
||||
@router.delete("/wipe_database", response_model=DocumentResponse)
|
||||
async def wipe_database(user: str = Depends(get_current_user)):
|
||||
user_id = get_user_id(user)
|
||||
user_folder = lollmsElfServer.lollms_paths / str(user_id)
|
||||
async def wipe_database(key: str):
|
||||
await validate_key(key)
|
||||
user_folder = lollmsElfServer.lollms_paths / str(key)
|
||||
shutil.rmtree(user_folder, ignore_errors=True)
|
||||
return DocumentResponse(success=True, message="Database wiped successfully.")
|
||||
|
Loading…
Reference in New Issue
Block a user