From b5dc256323694069acc3b9414a451f8380a73b43 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Wed, 9 Oct 2024 08:51:27 +0200 Subject: [PATCH] Fixed DOS vulenerability on boudaries in multipart chunks --- lollms/security.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/lollms/security.py b/lollms/security.py index e6ef812..7433d58 100644 --- a/lollms/security.py +++ b/lollms/security.py @@ -11,6 +11,8 @@ import string from lollms.utilities import PackageManager from starlette.middleware.base import BaseHTTPMiddleware from fastapi.responses import JSONResponse +from fastapi import Request + if not PackageManager.check_package_installed("lxml"): PackageManager.install_package("lxml") @@ -338,9 +340,21 @@ if __name__=="__main__": print(f"Original: {path}, Exception: {e.detail}") class MultipartBoundaryCheck(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch(self, request: Request, call_next): if request.headers.get("content-type", "").startswith("multipart/form-data"): boundary = request.headers.get("content-type").split("boundary=")[-1] - if len(boundary) > 70: # Adjust this limit as needed - return JSONResponse(status_code=400, content={"detail": "Invalid boundary length"}) + if len(boundary) > 70: # Check header boundary + return JSONResponse(status_code=400, content={"detail": "Invalid boundary length in header"}) + + # Read and check the request body + body = await request.body() + body_str = body.decode() + + # Check for excessively long boundaries in the body + pattern = re.escape(f"--{boundary}") + r"-*" + matches = re.findall(pattern, body_str) + for match in matches: + if len(match) > 74: # 70 + 4 for '--' prefix and '--' suffix + return JSONResponse(status_code=400, content={"detail": "Invalid boundary length in body"}) + return await call_next(request)