mirror of
https://github.com/ParisNeo/lollms.git
synced 2024-12-20 05:08:00 +00:00
enhanced lollms client
This commit is contained in:
parent
55989f795b
commit
767644c305
118
lollms/code_modifier.py
Normal file
118
lollms/code_modifier.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
import re
|
||||||
|
|
||||||
|
class CodeModifier:
    """Apply LLM-generated unified diffs to files inside a git-tracked folder.

    The workflow implemented by :meth:`modify_code` is: build a prompt that
    contains the current file content, send it to a caller-supplied LLM
    function, extract a unified diff and a commit message from the answer,
    patch the file, and commit the result with git.
    """

    # Matches a hunk header such as ``@@ -1,5 +1,6 @@``. Group 1 is the
    # 1-based start line of the hunk in the *new* version of the file.
    _HUNK_HEADER = re.compile(r'@@ -\d+,?\d* \+(\d+),?\d* @@')

    def __init__(self, folder_path: str):
        """Remember the absolute folder path and make sure it is a git repo.

        Args:
            folder_path: Folder that contains the files to modify.
        """
        self.folder_path = os.path.abspath(folder_path)
        self.init_git()

    def init_git(self):
        """Run ``git init`` in the folder unless a ``.git`` entry already exists.

        Raises:
            subprocess.CalledProcessError: If ``git init`` exits with a
                non-zero status.
        """
        if not os.path.exists(os.path.join(self.folder_path, '.git')):
            subprocess.run(['git', 'init'], cwd=self.folder_path, check=True)

    def generate_llm_prompt(self, file_path: str, modification_instruction: str) -> str:
        """Build the prompt asking the LLM for a diff and a commit message.

        Args:
            file_path: Path of the file, relative to the managed folder.
            modification_instruction: Free-text description of the change.

        Returns:
            The prompt text, embedding the current content of the file.
        """
        with open(os.path.join(self.folder_path, file_path), 'r') as file:
            content = file.read()

        # The final sentence spells out the 'Commit message:' marker because
        # parse_llm_response() rejects any answer that does not contain it.
        prompt = f"""
Task: Modify the following code according to the instructions.
File: {file_path}
Instructions: {modification_instruction}

Please provide your response in the following format:
1. A unified diff of the changes
2. A commit message for the changes

Current code:
{content}

Make sure to use the unified diff format for your changes, starting with '--- {file_path}' and '+++ {file_path}', followed by '@@ ... @@' for line numbers.
Put the commit message after the diff, on the lines following a line that reads 'Commit message:'.
"""
        return prompt

    def parse_llm_response(self, response: str) -> Tuple[str, str]:
        """Split an LLM answer into its diff and its commit message.

        The diff is everything from the first ``---`` up to the line before
        ``Commit message:`` (or the end of the text); the commit message is
        everything after ``Commit message:``. Both are whitespace-stripped.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            A ``(diff, commit_message)`` tuple.

        Raises:
            ValueError: If either part cannot be found in ``response``.
        """
        diff_pattern = r'(---[\s\S]*?)(?=\nCommit message:|$)'
        commit_pattern = r'Commit message:([\s\S]*)'

        diff_match = re.search(diff_pattern, response)
        commit_match = re.search(commit_pattern, response)

        if not diff_match or not commit_match:
            raise ValueError("Invalid LLM response format")

        diff = diff_match.group(1).strip()
        commit_message = commit_match.group(1).strip()

        return diff, commit_message

    def apply_diff(self, file_path: str, diff: str) -> None:
        """Patch a file in place with a unified diff.

        Hunk headers position a cursor using the new-file start line, ``-``
        lines remove the line under the cursor (compared with surrounding
        whitespace ignored), ``+`` lines insert at the cursor, and any other
        line is treated as context and only advances the cursor. The file is
        written back only after the whole diff has been processed, so a
        failing diff leaves the file untouched.

        Args:
            file_path: Path of the file, relative to the managed folder.
            diff: Unified diff text, with or without ``---``/``+++`` headers.

        Raises:
            ValueError: If a removed line does not match the file content or
                lies beyond the end of the file.
        """
        target = os.path.join(self.folder_path, file_path)
        with open(target, 'r') as file:
            lines = file.readlines()

        # Tolerate CRLF line endings in the diff text.
        diff_lines = [entry.rstrip('\r') for entry in diff.split('\n')]

        # Skip leading blank lines, then the '---' / '+++' file headers, but
        # only when they are really there: a diff that starts directly with a
        # hunk header must not lose its first two lines.
        start = 0
        while start < len(diff_lines) and not diff_lines[start].strip():
            start += 1
        if start < len(diff_lines) and diff_lines[start].startswith('---'):
            start += 1
            if start < len(diff_lines) and diff_lines[start].startswith('+++'):
                start += 1

        current_line = 0
        for line in diff_lines[start:]:
            if line.startswith('@@'):
                match = self._HUNK_HEADER.match(line)
                if match:
                    # A '+0,0' header (new file is empty) would otherwise
                    # yield -1 and address the last line of the file.
                    current_line = max(int(match.group(1)) - 1, 0)
            elif line.startswith('\\'):
                # Markers such as '\ No newline at end of file' describe the
                # previous diff line; they are not file content.
                continue
            elif line.startswith('-'):
                if current_line >= len(lines):
                    raise ValueError(
                        f"Mismatch at line {current_line + 1}: "
                        f"the file only has {len(lines)} lines"
                    )
                if lines[current_line].strip() == line[1:].strip():
                    lines.pop(current_line)
                else:
                    raise ValueError(f"Mismatch at line {current_line + 1}")
            elif line.startswith('+'):
                # If the preceding line is a final line without a newline,
                # terminate it so the inserted text does not get glued to it.
                if 0 < current_line <= len(lines) and not lines[current_line - 1].endswith('\n'):
                    lines[current_line - 1] += '\n'
                lines.insert(current_line, line[1:] + '\n')
                current_line += 1
            else:
                current_line += 1

        with open(target, 'w') as file:
            file.writelines(lines)

    def commit_changes(self, file_path: str, commit_message: str) -> None:
        """Stage one file and commit it.

        Args:
            file_path: Path of the file, relative to the managed folder.
            commit_message: Message passed to ``git commit -m``.

        Raises:
            RuntimeError: If ``git add`` or ``git commit`` exits with a
                non-zero status; the original error is chained as the cause.
        """
        try:
            subprocess.run(['git', 'add', file_path], cwd=self.folder_path, check=True)
            subprocess.run(['git', 'commit', '-m', commit_message], cwd=self.folder_path, check=True)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Failed to commit changes: {e}") from e

    def modify_code(self, file_path: str, modification_instruction: str, llm_function) -> Dict[str, str]:
        """Ask the LLM for a change, apply it to the file and commit it.

        Args:
            file_path: Path of the file, relative to the managed folder.
            modification_instruction: Free-text description of the change.
            llm_function: Callable taking the prompt string and returning the
                LLM answer as a string.

        Returns:
            A dict with a ``"status"`` of ``"success"`` or ``"error"`` and a
            human-readable ``"message"``. Failures while parsing, patching or
            committing are reported through this dict rather than raised;
            errors while building the prompt or calling the LLM propagate.
        """
        prompt = self.generate_llm_prompt(file_path, modification_instruction)
        llm_response = llm_function(prompt)

        try:
            diff, commit_message = self.parse_llm_response(llm_response)
            print("Generated diff:")
            print(diff)
            self.apply_diff(file_path, diff)
            self.commit_changes(file_path, commit_message)
            return {"status": "success", "message": "Code modified and committed successfully"}
        except Exception as e:
            # Deliberate catch-all: callers get a status dict, not a traceback.
            return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
def mock_llm_function(prompt):
    """Return a canned LLM answer for demonstration purposes.

    Args:
        prompt: Ignored; present only to match the expected LLM callable
            signature (one prompt string in, one answer string out).

    Returns:
        A fixed answer made of a unified diff for ``example.py`` followed by a
        ``Commit message:`` section.
    """
    # Context lines carry the leading space required by the unified diff
    # format; removed and added lines keep the indentation of the code.
    return """
--- example.py
+++ example.py
@@ -1,5 +1,6 @@
 def hello_world():
-    print("Hello, World!")
+    print("Hello, Universe!")
+    return "Greetings from the cosmos!"

 hello_world()

Commit message:
Update hello_world function to greet the universe and return a message
"""
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Demo: patch example.py in a target folder using the mock LLM above.
    # The folder can be given as the first command-line argument; without
    # one, the original hard-coded developer path is used.
    import sys

    target_folder = sys.argv[1] if len(sys.argv) > 1 else r"C:\Users\aloui\Documents\ai\test_code_modif"
    modifier = CodeModifier(target_folder)
    result = modifier.modify_code(
        "example.py",
        "Change the greeting to 'Hello, Universe!' and make the function return a string",
        mock_llm_function,
    )
    print(result)
|
@ -1,29 +1,6 @@
|
|||||||
"""
|
|
||||||
project: lollms_webui
|
|
||||||
file: lollms_rag.py
|
|
||||||
author: ParisNeo
|
|
||||||
description:
|
|
||||||
This module contains a set of FastAPI routes that allow users to interact with the RAG (Retrieval-Augmented Generation) library.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
1. Initialize the RAG system by adding documents using the /add_document endpoint.
|
|
||||||
2. Build the index using the /index_database endpoint.
|
|
||||||
3. Perform searches using the /search endpoint.
|
|
||||||
4. Remove documents using the /remove_document/{document_id} endpoint.
|
|
||||||
5. Wipe the entire database using the /wipe_database endpoint.
|
|
||||||
|
|
||||||
Authentication:
|
|
||||||
- If lollms_access_keys are specified in the configuration, API key authentication is required.
|
|
||||||
- If no keys are specified, authentication is bypassed, and all users are treated as user ID 1.
|
|
||||||
|
|
||||||
User Management:
|
|
||||||
- Each user gets a unique vectorizer based on their API key.
|
|
||||||
- If no API keys are specified, all requests are treated as coming from user ID 1.
|
|
||||||
|
|
||||||
Note: Ensure proper security measures are in place when deploying this API in a production environment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, HTTPException, Depends, Header
|
from fastapi import APIRouter, Request, HTTPException, Depends, Header
|
||||||
|
import uuid
|
||||||
|
from fastapi import Request, Response
|
||||||
from lollms_webui import LOLLMSWebUI
|
from lollms_webui import LOLLMSWebUI
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
@ -34,7 +11,6 @@ from ascii_colors import ASCIIColors
|
|||||||
from lollms.databases.discussions_database import DiscussionsDB, Discussion
|
from lollms.databases.discussions_database import DiscussionsDB, Discussion
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi.security import APIKeyHeader
|
|
||||||
from lollmsvectordb.database_elements.chunk import Chunk
|
from lollmsvectordb.database_elements.chunk import Chunk
|
||||||
from lollmsvectordb.vector_database import VectorDatabase
|
from lollmsvectordb.vector_database import VectorDatabase
|
||||||
from lollmsvectordb.lollms_vectorizers.bert_vectorizer import BERTVectorizer
|
from lollmsvectordb.lollms_vectorizers.bert_vectorizer import BERTVectorizer
|
||||||
@ -53,12 +29,12 @@ import hashlib
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
lollmsElfServer: LOLLMSWebUI = LOLLMSWebUI.get_instance()
|
lollmsElfServer: LOLLMSWebUI = LOLLMSWebUI.get_instance()
|
||||||
api_key_header = APIKeyHeader(name="Authorization")
|
|
||||||
|
|
||||||
# ----------------------- RAG System ------------------------------
|
# ----------------------- RAG System ------------------------------
|
||||||
|
|
||||||
class RAGQuery(BaseModel):
|
class RAGQuery(BaseModel):
|
||||||
query: str = Field(..., description="The query to process using RAG")
|
query: str = Field(..., description="The query to process using RAG")
|
||||||
|
key: str = Field(..., description="The key to identify the user")
|
||||||
|
|
||||||
class RAGResponse(BaseModel):
|
class RAGResponse(BaseModel):
|
||||||
answer: str = Field(..., description="The generated answer")
|
answer: str = Field(..., description="The generated answer")
|
||||||
@ -68,6 +44,7 @@ class IndexDocument(BaseModel):
|
|||||||
title: str = Field(..., description="The title of the document")
|
title: str = Field(..., description="The title of the document")
|
||||||
content: str = Field(..., description="The content to be indexed")
|
content: str = Field(..., description="The content to be indexed")
|
||||||
path: str = Field(default="unknown", description="The path of the document")
|
path: str = Field(default="unknown", description="The path of the document")
|
||||||
|
key: str = Field(..., description="The key to identify the user")
|
||||||
|
|
||||||
class IndexResponse(BaseModel):
|
class IndexResponse(BaseModel):
|
||||||
success: bool = Field(..., description="Indicates if the indexing was successful")
|
success: bool = Field(..., description="Indicates if the indexing was successful")
|
||||||
@ -86,22 +63,9 @@ class RAGChunk(BaseModel):
|
|||||||
nb_tokens : int
|
nb_tokens : int
|
||||||
distance : float
|
distance : float
|
||||||
|
|
||||||
def get_user_id(bearer_key: str) -> int:
|
def get_user_vectorizer(user_key: str):
|
||||||
"""
|
small_key = hashlib.md5(user_key.encode()).hexdigest()[:8]
|
||||||
Determine the user ID based on the bearer key.
|
user_folder = lollmsElfServer.lollms_paths.personal_outputs_path / str(user_key)
|
||||||
If no keys are specified in the configuration, always return 1.
|
|
||||||
"""
|
|
||||||
if not lollmsElfServer.config.lollms_access_keys:
|
|
||||||
return 1
|
|
||||||
# Use the index of the key in the list as the user ID
|
|
||||||
try:
|
|
||||||
return lollmsElfServer.config.lollms_access_keys.index(bearer_key) + 1
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(status_code=403, detail="Invalid API Key")
|
|
||||||
|
|
||||||
def get_user_vectorizer(user_id: int, bearer_key: str):
|
|
||||||
small_key = hashlib.md5(bearer_key.encode()).hexdigest()[:8]
|
|
||||||
user_folder = lollmsElfServer.lollms_paths / str(user_id)
|
|
||||||
user_folder.mkdir(parents=True, exist_ok=True)
|
user_folder.mkdir(parents=True, exist_ok=True)
|
||||||
return VectorDatabase(
|
return VectorDatabase(
|
||||||
str(user_folder / f"rag_db_{small_key}.sqlite"),
|
str(user_folder / f"rag_db_{small_key}.sqlite"),
|
||||||
@ -111,45 +75,58 @@ def get_user_vectorizer(user_id: int, bearer_key: str):
|
|||||||
overlap=lollmsElfServer.config.rag_overlap
|
overlap=lollmsElfServer.config.rag_overlap
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_current_user(bearer_token: str = Depends(api_key_header)):
|
async def validate_key(key: str):
|
||||||
if lollmsElfServer.config.lollms_access_keys:
|
if lollmsElfServer.config.lollms_access_keys and key not in lollmsElfServer.config.lollms_access_keys:
|
||||||
if bearer_token not in lollmsElfServer.config.lollms_access_keys:
|
raise HTTPException(status_code=403, detail="Invalid Key")
|
||||||
raise HTTPException(status_code=403, detail="Invalid API Key")
|
return key
|
||||||
return bearer_token
|
|
||||||
|
|
||||||
@router.post("/add_document", response_model=DocumentResponse)
|
@router.post("/add_document", response_model=DocumentResponse)
|
||||||
async def add_document(doc: IndexDocument, user: str = Depends(get_current_user)):
|
async def add_document(doc: IndexDocument):
|
||||||
user_id = get_user_id(user)
|
await validate_key(doc.key)
|
||||||
vectorizer = get_user_vectorizer(user_id, user)
|
vectorizer = get_user_vectorizer(doc.key)
|
||||||
vectorizer.add_document(title=doc.title, text=doc.content, path=doc.path)
|
vectorizer.add_document(title=doc.title, text=doc.content, path=doc.path)
|
||||||
return DocumentResponse(success=True, message="Document added successfully.")
|
return DocumentResponse(success=True, message="Document added successfully.")
|
||||||
|
|
||||||
@router.post("/remove_document/{document_id}", response_model=DocumentResponse)
|
@router.post("/remove_document/{document_id}", response_model=DocumentResponse)
|
||||||
async def remove_document(document_id: int, user: str = Depends(get_current_user)):
|
async def remove_document(document_id: int, key: str):
|
||||||
user_id = get_user_id(user)
|
await validate_key(key)
|
||||||
vectorizer = get_user_vectorizer(user_id, user)
|
vectorizer = get_user_vectorizer(key)
|
||||||
doc_hash = vectorizer.get_document_hash(document_id)
|
doc_hash = vectorizer.get_document_hash(document_id)
|
||||||
vectorizer.remove_document(doc_hash)
|
vectorizer.remove_document(doc_hash)
|
||||||
# Logic to remove the document by ID
|
|
||||||
return DocumentResponse(success=True, message="Document removed successfully.")
|
return DocumentResponse(success=True, message="Document removed successfully.")
|
||||||
|
|
||||||
|
class IndexDatabaseRequest(BaseModel):
|
||||||
|
key: str
|
||||||
|
|
||||||
@router.post("/index_database", response_model=DocumentResponse)
|
@router.post("/index_database", response_model=DocumentResponse)
|
||||||
async def index_database(user: str = Depends(get_current_user)):
|
async def index_database(request: IndexDatabaseRequest):
|
||||||
user_id = get_user_id(user)
|
key = request.key
|
||||||
vectorizer = get_user_vectorizer(user_id, user)
|
await validate_key(key)
|
||||||
|
vectorizer = get_user_vectorizer(key)
|
||||||
vectorizer.build_index()
|
vectorizer.build_index()
|
||||||
return DocumentResponse(success=True, message="Database indexed successfully.")
|
return DocumentResponse(success=True, message="Database indexed successfully.")
|
||||||
|
|
||||||
@router.post("/search", response_model=List[RAGChunk])
|
@router.post("/search", response_model=List[RAGChunk])
|
||||||
async def search(query: RAGQuery, user: str = Depends(get_current_user)):
|
async def search(query: RAGQuery):
|
||||||
user_id = get_user_id(user)
|
await validate_key(query.key)
|
||||||
vectorizer = get_user_vectorizer(user_id, user)
|
vectorizer = get_user_vectorizer(query.key)
|
||||||
chunks = vectorizer.search(query.query)
|
chunks = vectorizer.search(query.query)
|
||||||
return [RAGChunk(c.id,c.chunk_id, c.doc.title, c.doc.path, c.text, c.nb_tokens, c.distance) for c in chunks]
|
return [
|
||||||
|
RAGChunk(
|
||||||
|
id=c.id,
|
||||||
|
chunk_id=c.chunk_id,
|
||||||
|
title=c.doc.title,
|
||||||
|
path=c.doc.path,
|
||||||
|
text=c.text,
|
||||||
|
nb_tokens=c.nb_tokens,
|
||||||
|
distance=c.distance
|
||||||
|
)
|
||||||
|
for c in chunks
|
||||||
|
]
|
||||||
|
|
||||||
@router.delete("/wipe_database", response_model=DocumentResponse)
|
@router.delete("/wipe_database", response_model=DocumentResponse)
|
||||||
async def wipe_database(user: str = Depends(get_current_user)):
|
async def wipe_database(key: str):
|
||||||
user_id = get_user_id(user)
|
await validate_key(key)
|
||||||
user_folder = lollmsElfServer.lollms_paths / str(user_id)
|
user_folder = lollmsElfServer.lollms_paths / str(key)
|
||||||
shutil.rmtree(user_folder, ignore_errors=True)
|
shutil.rmtree(user_folder, ignore_errors=True)
|
||||||
return DocumentResponse(success=True, message="Database wiped successfully.")
|
return DocumentResponse(success=True, message="Database wiped successfully.")
|
||||||
|
Loading…
Reference in New Issue
Block a user