mirror of
https://github.com/ParisNeo/lollms.git
synced 2025-03-01 04:06:07 +00:00
security fix
This commit is contained in:
parent
9942d177a1
commit
6b0675d5fe
@ -12,7 +12,7 @@ from lollms.utilities import PackageManager
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import Request
|
||||
|
||||
import asyncio
|
||||
|
||||
if not PackageManager.check_package_installed("lxml"):
|
||||
PackageManager.install_package("lxml")
|
||||
@ -338,7 +338,17 @@ if __name__=="__main__":
|
||||
print(f"Original: {path}, Sanitized: {sanitized}")
|
||||
except HTTPException as e:
|
||||
print(f"Original: {path}, Exception: {e.detail}")
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from fastapi import Request
|
||||
from starlette.responses import JSONResponse
|
||||
import re
|
||||
|
||||
class MultipartBoundaryCheck(BaseHTTPMiddleware):
|
||||
def __init__(self, app, max_boundary_length=70, max_trailing_boundary_length=100):
|
||||
super().__init__(app)
|
||||
self.max_boundary_length = max_boundary_length
|
||||
self.max_trailing_boundary_length = max_trailing_boundary_length
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.headers.get("content-type", "").startswith("multipart/form-data"):
|
||||
content_type = request.headers.get("content-type", "")
|
||||
@ -350,19 +360,22 @@ class MultipartBoundaryCheck(BaseHTTPMiddleware):
|
||||
boundary = content_type[boundary_start + 9:] # 9 is the length of "boundary="
|
||||
|
||||
# Check header boundary
|
||||
if len(boundary) > 70 or not self.is_valid_boundary(boundary):
|
||||
if len(boundary) > self.max_boundary_length or not self.is_valid_boundary(boundary):
|
||||
return JSONResponse(status_code=400, content={"detail": "Invalid boundary in header"})
|
||||
|
||||
# Read and check the request body
|
||||
# Check trailing boundary if it exists
|
||||
body = await request.body()
|
||||
body_str = body.decode(errors='ignore')
|
||||
trailing_boundary = b"--" + boundary.encode() + b"--"
|
||||
if trailing_boundary in body:
|
||||
trailing_content = body[body.rfind(trailing_boundary):]
|
||||
if len(trailing_content) > self.max_trailing_boundary_length:
|
||||
return JSONResponse(status_code=400, content={"detail": "Trailing boundary too long"})
|
||||
|
||||
# Check for excessively long or invalid boundaries in the body
|
||||
pattern = re.escape(f"--{boundary}") + r"(?:--)?$"
|
||||
matches = re.findall(pattern, body_str, re.MULTILINE)
|
||||
for match in matches:
|
||||
if len(match) > 74 or not self.is_valid_boundary(match[2:].rstrip('-')):
|
||||
return JSONResponse(status_code=400, content={"detail": "Invalid boundary in body"})
|
||||
# Check for multiple trailing boundaries
|
||||
if body.count(trailing_boundary) > 1:
|
||||
return JSONResponse(status_code=400, content={"detail": "Multiple trailing boundaries detected"})
|
||||
|
||||
# Note: We're not returning an error if there's no trailing boundary
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user