diff --git a/lollms/security.py b/lollms/security.py index f6ac7ab..9e989ce 100644 --- a/lollms/security.py +++ b/lollms/security.py @@ -14,31 +14,87 @@ def check_access(lollmsElfServer, client_id): raise HTTPException(status_code=400, detail=f"Not accessible without id") return client -def sanitize_code(code): + +def sanitize_based_on_separators(line): + """ + Sanitizes a line of code based on common command separators. + + Parameters: + - line (str): The line of code to be sanitized. + + Returns: + - str: The sanitized line of code. + """ + separators = ['&', '|', ';'] + for sep in separators: + if sep in line: + line = line.split(sep)[0] # Keep only the first command before the separator + break + return line.strip() + +def sanitize_after_whitelisted_command(line, command): + """ + Sanitizes the line after a whitelisted command, removing any following commands + if a command separator is present. + + Parameters: + - line (str): The line of code containing the whitelisted command. + - command (str): The whitelisted command. + + Returns: + - str: The sanitized line of code, ensuring only the whitelisted command is executed. + """ + # Find the end of the whitelisted command in the line + command_end_index = line.find(command) + len(command) + # Extract the rest of the line after the whitelisted command + rest_of_line = line[command_end_index:] + # Sanitize the rest of the line based on separators + sanitized_rest = sanitize_based_on_separators(rest_of_line) + # If anything malicious was removed, sanitized_rest will be empty, so only return the whitelisted command part + if not sanitized_rest: + return line[:command_end_index].strip() + else: + # If rest_of_line starts directly with separators followed by malicious commands, sanitized_rest will be empty + # This means we should only return the part up to the whitelisted command + return line[:command_end_index + len(sanitized_rest)].strip() + + +def sanitize_shell_code(code, whitelist=None): + """ + Securely sanitizes a block of code by allowing commands from a provided whitelist, + but only up to the first command separator if followed by other commands. + Sanitizes based on common command separators if no whitelist is provided. + + Parameters: + - code (str): The input code to be sanitized. + - whitelist (list): Optional. A list of whitelisted commands that are allowed. + + Returns: + - str: The securely sanitized code. + """ + # Split the code by newline characters lines = code.split('\n') - # Keep only the first non-empty line and remove any potential malicious commands + # Initialize the sanitized code variable sanitized_code = "" for line in lines: if line.strip(): # Check if the line is not empty - # Check for potential malicious commands - if platform.system() == "Windows": - if "&" in line: - line = line.split("&")[0] # Keep only the first command before the ampersand - if "|" in line: - line = line.split("|")[0] # Keep only the first command before the pipe - else: # Linux - if ";" in line: - line = line.split(";")[0] # Keep only the first command before the semicolon - if "|" in line: - line = line.split("|")[0] # Keep only the first command before the pipe - sanitized_code = line - break + if whitelist: + for command in whitelist: + if line.strip().startswith(command): + # Check for command separators after the whitelisted command + sanitized_code = sanitize_after_whitelisted_command(line, command) + break + else: + # Sanitize based on separators if no whitelist is provided + sanitized_code = sanitize_based_on_separators(line) + break # Only process the first non-empty line return sanitized_code + def sanitize_path(path:str, allow_absolute_path:bool=False, error_text="Absolute database path detected", exception_text="Detected an attempt of path traversal. Are you kidding me?"): if not allow_absolute_path and path.strip().startswith("/"): raise HTTPException(status_code=400, detail=exception_text) diff --git a/lollms/utilities.py b/lollms/utilities.py index 80e1c67..c57f4d7 100644 --- a/lollms/utilities.py +++ b/lollms/utilities.py @@ -38,13 +38,13 @@ import git import mimetypes import subprocess -from lollms.security import sanitize_code +from lollms.security import sanitize_shell_code from functools import partial def create_conda_env(env_name, python_version): - env_name = sanitize_code(env_name) - python_version = sanitize_code(python_version) + env_name = sanitize_shell_code(env_name) + python_version = sanitize_shell_code(python_version) # Activate the Conda environment import platform if platform.system()=="Windows":