diff --git a/lollms/security.py b/lollms/security.py index f5f3fa4..9374fdb 100644 --- a/lollms/security.py +++ b/lollms/security.py @@ -65,31 +65,56 @@ def sanitize_after_whitelisted_command(line, command): # This means we should only return the part up to the whitelisted command return line[:command_end_index + len(sanitized_rest)].strip() +if not(PackageManager.check_package_installed("defusedxml")): + PackageManager.install_or_update("defusedxml") + +import defusedxml.ElementTree as ET + +from defusedxml import ElementTree as ET +from io import StringIO def sanitize_svg(svg_content): try: - parser = ET.XMLParser(remove_comments=True, remove_pis=True) - tree = ET.fromstring(svg_content, parser=parser) + # Use defusedxml's parse function with a StringIO object + tree = ET.parse(StringIO(svg_content)) + root = tree.getroot() - # Remove any script elements - for script in tree.xpath('//svg:script', namespaces={'svg': 'http://www.w3.org/2000/svg'}): - parent = script.getparent() - if parent is not None: - parent.remove(script) + # Define a list of allowed elements + allowed_elements = { + 'svg', 'g', 'path', 'circle', 'rect', 'line', 'polyline', 'polygon', + 'text', 'tspan', 'defs', 'filter', 'feGaussianBlur', 'feMerge', + 'feMergeNode', 'linearGradient', 'radialGradient', 'stop' + } - # Remove any 'on*' event attributes - for element in tree.xpath('//*[@*[starts-with(name(), "on")]]'): + # Define a list of allowed attributes + allowed_attributes = { + 'id', 'class', 'style', 'fill', 'stroke', 'stroke-width', 'cx', 'cy', + 'r', 'x', 'y', 'width', 'height', 'd', 'transform', 'viewBox', + 'xmlns', 'xmlns:xlink', 'version', 'stdDeviation', 'result', 'in', + 'x1', 'y1', 'x2', 'y2', 'offset', 'stop-color', 'stop-opacity' + } + + # Remove any disallowed elements + for element in root.iter(): + if element.tag.split('}')[-1] not in allowed_elements: + parent = element.getparent() + if parent is not None: + parent.remove(element) + + # Remove any disallowed attributes + for element in root.iter(): for attr in list(element.attrib): - if attr.startswith('on'): + if attr not in allowed_attributes: del element.attrib[attr] # Convert the tree back to an SVG string - sanitized_svg = ET.tostring(tree, encoding='unicode', method='xml') + sanitized_svg = ET.tostring(root, encoding='unicode', method='xml') return sanitized_svg - except ET.XMLSyntaxError as e: + except ET.ParseError as e: raise ValueError("Invalid SVG content") from e + def sanitize_shell_code(code, whitelist=None): """ Securely sanitizes a block of code by allowing commands from a provided whitelist,