security fix

This commit is contained in:
Saifeddine ALOUI 2024-11-12 22:28:15 +01:00
parent 9942d177a1
commit 6b0675d5fe

View File

@ -12,7 +12,7 @@ from lollms.utilities import PackageManager
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import JSONResponse
from fastapi import Request
import asyncio
if not PackageManager.check_package_installed("lxml"):
PackageManager.install_package("lxml")
@ -338,7 +338,17 @@ if __name__=="__main__":
print(f"Original: {path}, Sanitized: {sanitized}")
except HTTPException as e:
print(f"Original: {path}, Exception: {e.detail}")
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi import Request
from starlette.responses import JSONResponse
import re
class MultipartBoundaryCheck(BaseHTTPMiddleware):
def __init__(self, app, max_boundary_length=70, max_trailing_boundary_length=100):
super().__init__(app)
self.max_boundary_length = max_boundary_length
self.max_trailing_boundary_length = max_trailing_boundary_length
async def dispatch(self, request: Request, call_next):
if request.headers.get("content-type", "").startswith("multipart/form-data"):
content_type = request.headers.get("content-type", "")
@ -350,19 +360,22 @@ class MultipartBoundaryCheck(BaseHTTPMiddleware):
boundary = content_type[boundary_start + 9:] # 9 is the length of "boundary="
# Check header boundary
if len(boundary) > 70 or not self.is_valid_boundary(boundary):
if len(boundary) > self.max_boundary_length or not self.is_valid_boundary(boundary):
return JSONResponse(status_code=400, content={"detail": "Invalid boundary in header"})
# Read and check the request body
# Check trailing boundary if it exists
body = await request.body()
body_str = body.decode(errors='ignore')
trailing_boundary = b"--" + boundary.encode() + b"--"
if trailing_boundary in body:
trailing_content = body[body.rfind(trailing_boundary):]
if len(trailing_content) > self.max_trailing_boundary_length:
return JSONResponse(status_code=400, content={"detail": "Trailing boundary too long"})
# Check for excessively long or invalid boundaries in the body
pattern = re.escape(f"--{boundary}") + r"(?:--)?$"
matches = re.findall(pattern, body_str, re.MULTILINE)
for match in matches:
if len(match) > 74 or not self.is_valid_boundary(match[2:].rstrip('-')):
return JSONResponse(status_code=400, content={"detail": "Invalid boundary in body"})
# Check for multiple trailing boundaries
if body.count(trailing_boundary) > 1:
return JSONResponse(status_code=400, content={"detail": "Multiple trailing boundaries detected"})
# Note: We're not returning an error if there's no trailing boundary
return await call_next(request)