mirror of
https://github.com/GNS3/gns3-server.git
synced 2025-01-08 22:12:41 +00:00
Merge pull request #2070 from GNS3/project-export-zstd
zstandard compression support for project export
This commit is contained in:
commit
466aaf5c13
@ -21,10 +21,10 @@ API routes for projects.
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
import gns3server.utils.zipfile_zstd as zipfile
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ from pathlib import Path
|
|||||||
from gns3server import schemas
|
from gns3server import schemas
|
||||||
from gns3server.controller import Controller
|
from gns3server.controller import Controller
|
||||||
from gns3server.controller.project import Project
|
from gns3server.controller.project import Project
|
||||||
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
|
from gns3server.controller.controller_error import ControllerError, ControllerBadRequestError
|
||||||
from gns3server.controller.import_project import import_project as import_controller_project
|
from gns3server.controller.import_project import import_project as import_controller_project
|
||||||
from gns3server.controller.export_project import export_project as export_controller_project
|
from gns3server.controller.export_project import export_project as export_controller_project
|
||||||
from gns3server.utils.asyncio import aiozipstream
|
from gns3server.utils.asyncio import aiozipstream
|
||||||
@ -285,7 +285,8 @@ async def export_project(
|
|||||||
include_snapshots: bool = False,
|
include_snapshots: bool = False,
|
||||||
include_images: bool = False,
|
include_images: bool = False,
|
||||||
reset_mac_addresses: bool = False,
|
reset_mac_addresses: bool = False,
|
||||||
compression: str = "zip",
|
compression: schemas.ProjectCompression = "zstd",
|
||||||
|
compression_level: int = None,
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""
|
"""
|
||||||
Export a project as a portable archive.
|
Export a project as a portable archive.
|
||||||
@ -294,12 +295,23 @@ async def export_project(
|
|||||||
compression_query = compression.lower()
|
compression_query = compression.lower()
|
||||||
if compression_query == "zip":
|
if compression_query == "zip":
|
||||||
compression = zipfile.ZIP_DEFLATED
|
compression = zipfile.ZIP_DEFLATED
|
||||||
|
if compression_level is not None and (compression_level < 0 or compression_level > 9):
|
||||||
|
raise ControllerBadRequestError("Compression level must be between 0 and 9 for ZIP compression")
|
||||||
elif compression_query == "none":
|
elif compression_query == "none":
|
||||||
compression = zipfile.ZIP_STORED
|
compression = zipfile.ZIP_STORED
|
||||||
elif compression_query == "bzip2":
|
elif compression_query == "bzip2":
|
||||||
compression = zipfile.ZIP_BZIP2
|
compression = zipfile.ZIP_BZIP2
|
||||||
|
if compression_level is not None and (compression_level < 1 or compression_level > 9):
|
||||||
|
raise ControllerBadRequestError("Compression level must be between 1 and 9 for BZIP2 compression")
|
||||||
elif compression_query == "lzma":
|
elif compression_query == "lzma":
|
||||||
compression = zipfile.ZIP_LZMA
|
compression = zipfile.ZIP_LZMA
|
||||||
|
elif compression_query == "zstd":
|
||||||
|
compression = zipfile.ZIP_ZSTANDARD
|
||||||
|
if compression_level is not None and (compression_level < 1 or compression_level > 22):
|
||||||
|
raise ControllerBadRequestError("Compression level must be between 1 and 22 for Zstandard compression")
|
||||||
|
|
||||||
|
if compression_level is not None and compression_query in ("none", "lzma"):
|
||||||
|
raise ControllerBadRequestError(f"Compression level is not supported for '{compression_query}' compression method")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
begin = time.time()
|
begin = time.time()
|
||||||
@ -307,8 +319,10 @@ async def export_project(
|
|||||||
working_dir = os.path.abspath(os.path.join(project.path, os.pardir))
|
working_dir = os.path.abspath(os.path.join(project.path, os.pardir))
|
||||||
|
|
||||||
async def streamer():
|
async def streamer():
|
||||||
|
log.info(f"Exporting project '{project.name}' with '{compression_query}' compression "
|
||||||
|
f"(level {compression_level})")
|
||||||
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir:
|
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir:
|
||||||
with aiozipstream.ZipFile(compression=compression) as zstream:
|
with aiozipstream.ZipFile(compression=compression, compresslevel=compression_level) as zstream:
|
||||||
await export_controller_project(
|
await export_controller_project(
|
||||||
zstream,
|
zstream,
|
||||||
project,
|
project,
|
||||||
|
@ -166,12 +166,14 @@ async def sqlalchemry_error_handler(request: Request, exc: SQLAlchemyError):
|
|||||||
content={"message": "Database error detected, please check logs to find details"},
|
content={"message": "Database error detected, please check logs to find details"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: do not use this middleware since it creates issue when using StreamingResponse
|
||||||
|
# see https://starlette-context.readthedocs.io/en/latest/middleware.html#why-are-there-two-middlewares-that-do-the-same-thing
|
||||||
|
|
||||||
@app.middleware("http")
|
# @app.middleware("http")
|
||||||
async def add_extra_headers(request: Request, call_next):
|
# async def add_extra_headers(request: Request, call_next):
|
||||||
start_time = time.time()
|
# start_time = time.time()
|
||||||
response = await call_next(request)
|
# response = await call_next(request)
|
||||||
process_time = time.time() - start_time
|
# process_time = time.time() - start_time
|
||||||
response.headers["X-Process-Time"] = str(process_time)
|
# response.headers["X-Process-Time"] = str(process_time)
|
||||||
response.headers["X-GNS3-Server-Version"] = f"{__version__}"
|
# response.headers["X-GNS3-Server-Version"] = f"{__version__}"
|
||||||
return response
|
# return response
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
@ -20,10 +20,10 @@ import sys
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import shutil
|
import shutil
|
||||||
import zipfile
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import itertools
|
import itertools
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import gns3server.utils.zipfile_zstd as zipfile_zstd
|
||||||
|
|
||||||
from .controller_error import ControllerError
|
from .controller_error import ControllerError
|
||||||
from .topology import load_topology
|
from .topology import load_topology
|
||||||
@ -60,9 +60,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
|
|||||||
raise ControllerError("The destination path should not contain .gns3")
|
raise ControllerError("The destination path should not contain .gns3")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(stream) as zip_file:
|
with zipfile_zstd.ZipFile(stream) as zip_file:
|
||||||
project_file = zip_file.read("project.gns3").decode()
|
project_file = zip_file.read("project.gns3").decode()
|
||||||
except zipfile.BadZipFile:
|
except zipfile_zstd.BadZipFile:
|
||||||
raise ControllerError("Cannot import project, not a GNS3 project (invalid zip)")
|
raise ControllerError("Cannot import project, not a GNS3 project (invalid zip)")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ControllerError("Cannot import project, project.gns3 file could not be found")
|
raise ControllerError("Cannot import project, project.gns3 file could not be found")
|
||||||
@ -92,9 +92,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
|
|||||||
raise ControllerError("The project name contain non supported or invalid characters")
|
raise ControllerError("The project name contain non supported or invalid characters")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(stream) as zip_file:
|
with zipfile_zstd.ZipFile(stream) as zip_file:
|
||||||
await wait_run_in_executor(zip_file.extractall, path)
|
await wait_run_in_executor(zip_file.extractall, path)
|
||||||
except zipfile.BadZipFile:
|
except zipfile_zstd.BadZipFile:
|
||||||
raise ControllerError("Cannot extract files from GNS3 project (invalid zip)")
|
raise ControllerError("Cannot extract files from GNS3 project (invalid zip)")
|
||||||
|
|
||||||
topology = load_topology(os.path.join(path, "project.gns3"))
|
topology = load_topology(os.path.join(path, "project.gns3"))
|
||||||
@ -264,11 +264,11 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
|
|||||||
# extract everything to a temporary directory
|
# extract everything to a temporary directory
|
||||||
try:
|
try:
|
||||||
with open(snapshot_path, "rb") as f:
|
with open(snapshot_path, "rb") as f:
|
||||||
with zipfile.ZipFile(f) as zip_file:
|
with zipfile_zstd.ZipFile(f) as zip_file:
|
||||||
await wait_run_in_executor(zip_file.extractall, tmpdir)
|
await wait_run_in_executor(zip_file.extractall, tmpdir)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise ControllerError(f"Cannot open snapshot '{os.path.basename(snapshot)}': {e}")
|
raise ControllerError(f"Cannot open snapshot '{os.path.basename(snapshot)}': {e}")
|
||||||
except zipfile.BadZipFile:
|
except zipfile_zstd.BadZipFile:
|
||||||
raise ControllerError(
|
raise ControllerError(
|
||||||
f"Cannot extract files from snapshot '{os.path.basename(snapshot)}': not a GNS3 project (invalid zip)"
|
f"Cannot extract files from snapshot '{os.path.basename(snapshot)}': not a GNS3 project (invalid zip)"
|
||||||
)
|
)
|
||||||
@ -294,7 +294,7 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
|
|||||||
|
|
||||||
# write everything back to the original snapshot file
|
# write everything back to the original snapshot file
|
||||||
try:
|
try:
|
||||||
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
|
with aiozipstream.ZipFile(compression=zipfile_zstd.ZIP_STORED) as zstream:
|
||||||
for root, dirs, files in os.walk(tmpdir, topdown=True, followlinks=False):
|
for root, dirs, files in os.walk(tmpdir, topdown=True, followlinks=False):
|
||||||
for file in files:
|
for file in files:
|
||||||
path = os.path.join(root, file)
|
path = os.path.join(root, file)
|
||||||
|
@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance
|
|||||||
from .controller.drawings import Drawing
|
from .controller.drawings import Drawing
|
||||||
from .controller.gns3vm import GNS3VM
|
from .controller.gns3vm import GNS3VM
|
||||||
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
||||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
|
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
|
||||||
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
|
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
|
||||||
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
|
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
|
||||||
from .controller.tokens import Token
|
from .controller.tokens import Token
|
||||||
|
@ -102,3 +102,15 @@ class ProjectFile(BaseModel):
|
|||||||
|
|
||||||
path: str = Field(..., description="File path")
|
path: str = Field(..., description="File path")
|
||||||
md5sum: str = Field(..., description="File checksum")
|
md5sum: str = Field(..., description="File checksum")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectCompression(str, Enum):
|
||||||
|
"""
|
||||||
|
Supported project compression.
|
||||||
|
"""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
zip = "zip"
|
||||||
|
bzip2 = "bzip2"
|
||||||
|
lzma = "lzma"
|
||||||
|
zstd = "zstd"
|
||||||
|
@ -43,26 +43,38 @@ from zipfile import (
|
|||||||
stringEndArchive64Locator,
|
stringEndArchive64Locator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd
|
||||||
|
ZSTANDARD_VERSION = 20
|
||||||
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
|
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
|
||||||
|
|
||||||
|
|
||||||
def _get_compressor(compress_type):
|
def _get_compressor(compress_type, compresslevel=None):
|
||||||
"""
|
"""
|
||||||
Return the compressor.
|
Return the compressor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if compress_type == zipfile.ZIP_DEFLATED:
|
if compress_type == zipfile.ZIP_DEFLATED:
|
||||||
from zipfile import zlib
|
from zipfile import zlib
|
||||||
|
if compresslevel is not None:
|
||||||
|
return zlib.compressobj(compresslevel, zlib.DEFLATED, -15)
|
||||||
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
|
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
|
||||||
elif compress_type == zipfile.ZIP_BZIP2:
|
elif compress_type == zipfile.ZIP_BZIP2:
|
||||||
from zipfile import bz2
|
from zipfile import bz2
|
||||||
|
if compresslevel is not None:
|
||||||
|
return bz2.BZ2Compressor(compresslevel)
|
||||||
return bz2.BZ2Compressor()
|
return bz2.BZ2Compressor()
|
||||||
|
# compresslevel is ignored for ZIP_LZMA
|
||||||
elif compress_type == zipfile.ZIP_LZMA:
|
elif compress_type == zipfile.ZIP_LZMA:
|
||||||
from zipfile import LZMACompressor
|
from zipfile import LZMACompressor
|
||||||
|
|
||||||
return LZMACompressor()
|
return LZMACompressor()
|
||||||
|
elif compress_type == ZIP_ZSTANDARD:
|
||||||
|
import zstandard as zstd
|
||||||
|
if compresslevel is not None:
|
||||||
|
#params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31)
|
||||||
|
#return zstd.ZstdCompressor(compression_params=params).compressobj()
|
||||||
|
return zstd.ZstdCompressor(level=compresslevel).compressobj()
|
||||||
|
return zstd.ZstdCompressor().compressobj()
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -129,7 +141,15 @@ class ZipInfo(zipfile.ZipInfo):
|
|||||||
|
|
||||||
|
|
||||||
class ZipFile(zipfile.ZipFile):
|
class ZipFile(zipfile.ZipFile):
|
||||||
def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768):
|
def __init__(
|
||||||
|
self,
|
||||||
|
fileobj=None,
|
||||||
|
mode="w",
|
||||||
|
compression=zipfile.ZIP_STORED,
|
||||||
|
allowZip64=True,
|
||||||
|
compresslevel=None,
|
||||||
|
chunksize=32768
|
||||||
|
):
|
||||||
"""Open the ZIP file with mode write "w"."""
|
"""Open the ZIP file with mode write "w"."""
|
||||||
|
|
||||||
if mode not in ("w",):
|
if mode not in ("w",):
|
||||||
@ -138,7 +158,13 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
fileobj = PointerIO()
|
fileobj = PointerIO()
|
||||||
|
|
||||||
self._comment = b""
|
self._comment = b""
|
||||||
zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64)
|
zipfile.ZipFile.__init__(
|
||||||
|
self, fileobj,
|
||||||
|
mode=mode,
|
||||||
|
compression=compression,
|
||||||
|
compresslevel=compresslevel,
|
||||||
|
allowZip64=allowZip64
|
||||||
|
)
|
||||||
self._chunksize = chunksize
|
self._chunksize = chunksize
|
||||||
self.paths_to_write = []
|
self.paths_to_write = []
|
||||||
|
|
||||||
@ -195,23 +221,33 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
for chunk in self._close():
|
for chunk in self._close():
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def write(self, filename, arcname=None, compress_type=None):
|
def write(self, filename, arcname=None, compress_type=None, compresslevel=None):
|
||||||
"""
|
"""
|
||||||
Write a file to the archive under the name `arcname`.
|
Write a file to the archive under the name `arcname`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type}
|
kwargs = {
|
||||||
|
"filename": filename,
|
||||||
|
"arcname": arcname,
|
||||||
|
"compress_type": compress_type,
|
||||||
|
"compresslevel": compresslevel
|
||||||
|
}
|
||||||
self.paths_to_write.append(kwargs)
|
self.paths_to_write.append(kwargs)
|
||||||
|
|
||||||
def write_iter(self, arcname, iterable, compress_type=None):
|
def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None):
|
||||||
"""
|
"""
|
||||||
Write the bytes iterable `iterable` to the archive under the name `arcname`.
|
Write the bytes iterable `iterable` to the archive under the name `arcname`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type}
|
kwargs = {
|
||||||
|
"arcname": arcname,
|
||||||
|
"iterable": iterable,
|
||||||
|
"compress_type": compress_type,
|
||||||
|
"compresslevel": compresslevel
|
||||||
|
}
|
||||||
self.paths_to_write.append(kwargs)
|
self.paths_to_write.append(kwargs)
|
||||||
|
|
||||||
def writestr(self, arcname, data, compress_type=None):
|
def writestr(self, arcname, data, compress_type=None, compresslevel=None):
|
||||||
"""
|
"""
|
||||||
Writes a str into ZipFile by wrapping data as a generator
|
Writes a str into ZipFile by wrapping data as a generator
|
||||||
"""
|
"""
|
||||||
@ -219,9 +255,9 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
def _iterable():
|
def _iterable():
|
||||||
yield data
|
yield data
|
||||||
|
|
||||||
return self.write_iter(arcname, _iterable(), compress_type=compress_type)
|
return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel)
|
||||||
|
|
||||||
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None):
|
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None):
|
||||||
"""
|
"""
|
||||||
Put the bytes from filename into the archive under the name `arcname`.
|
Put the bytes from filename into the archive under the name `arcname`.
|
||||||
"""
|
"""
|
||||||
@ -256,6 +292,11 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
else:
|
else:
|
||||||
zinfo.compress_type = compress_type
|
zinfo.compress_type = compress_type
|
||||||
|
|
||||||
|
if compresslevel is None:
|
||||||
|
zinfo._compresslevel = self.compresslevel
|
||||||
|
else:
|
||||||
|
zinfo._compresslevel = compresslevel
|
||||||
|
|
||||||
if st:
|
if st:
|
||||||
zinfo.file_size = st[6]
|
zinfo.file_size = st[6]
|
||||||
else:
|
else:
|
||||||
@ -279,7 +320,7 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
yield self.fp.write(zinfo.FileHeader(False))
|
yield self.fp.write(zinfo.FileHeader(False))
|
||||||
return
|
return
|
||||||
|
|
||||||
cmpr = _get_compressor(zinfo.compress_type)
|
cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel)
|
||||||
|
|
||||||
# Must overwrite CRC and sizes with correct data later
|
# Must overwrite CRC and sizes with correct data later
|
||||||
zinfo.CRC = CRC = 0
|
zinfo.CRC = CRC = 0
|
||||||
@ -369,6 +410,8 @@ class ZipFile(zipfile.ZipFile):
|
|||||||
min_version = max(zipfile.BZIP2_VERSION, min_version)
|
min_version = max(zipfile.BZIP2_VERSION, min_version)
|
||||||
elif zinfo.compress_type == zipfile.ZIP_LZMA:
|
elif zinfo.compress_type == zipfile.ZIP_LZMA:
|
||||||
min_version = max(zipfile.LZMA_VERSION, min_version)
|
min_version = max(zipfile.LZMA_VERSION, min_version)
|
||||||
|
elif zinfo.compress_type == ZIP_ZSTANDARD:
|
||||||
|
min_version = max(ZSTANDARD_VERSION, min_version)
|
||||||
|
|
||||||
extract_version = max(min_version, zinfo.extract_version)
|
extract_version = max(min_version, zinfo.extract_version)
|
||||||
create_version = max(min_version, zinfo.create_version)
|
create_version = max(min_version, zinfo.create_version)
|
||||||
|
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
|
||||||
|
# NOTE: this patches the standard zipfile module
|
||||||
|
from . import _zipfile
|
||||||
|
|
||||||
|
from zipfile import *
|
||||||
|
from zipfile import (
|
||||||
|
ZIP_ZSTANDARD,
|
||||||
|
ZSTANDARD_VERSION,
|
||||||
|
)
|
||||||
|
|
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import functools
|
||||||
|
|
||||||
|
|
||||||
|
class patch:
|
||||||
|
|
||||||
|
originals = {}
|
||||||
|
|
||||||
|
def __init__(self, host, name):
|
||||||
|
self.host = host
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __call__(self, func):
|
||||||
|
original = getattr(self.host, self.name)
|
||||||
|
self.originals[self.name] = original
|
||||||
|
|
||||||
|
functools.update_wrapper(func, original)
|
||||||
|
setattr(self.host, self.name, func)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import zipfile
|
||||||
|
import zstandard as zstd
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from ._patcher import patch
|
||||||
|
|
||||||
|
|
||||||
|
zipfile.ZIP_ZSTANDARD = 93
|
||||||
|
zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard'
|
||||||
|
zipfile.ZSTANDARD_VERSION = 20
|
||||||
|
|
||||||
|
|
||||||
|
@patch(zipfile, '_check_compression')
|
||||||
|
def zstd_check_compression(compression):
|
||||||
|
if compression == zipfile.ZIP_ZSTANDARD:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
patch.originals['_check_compression'](compression)
|
||||||
|
|
||||||
|
|
||||||
|
class ZstdDecompressObjWrapper:
|
||||||
|
def __init__(self, o):
|
||||||
|
self.o = o
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr == 'eof':
|
||||||
|
return False
|
||||||
|
return getattr(self.o, attr)
|
||||||
|
|
||||||
|
|
||||||
|
@patch(zipfile, '_get_decompressor')
|
||||||
|
def zstd_get_decompressor(compress_type):
|
||||||
|
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||||
|
return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj())
|
||||||
|
else:
|
||||||
|
return patch.originals['_get_decompressor'](compress_type)
|
||||||
|
|
||||||
|
|
||||||
|
if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters:
|
||||||
|
@patch(zipfile, '_get_compressor')
|
||||||
|
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||||
|
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||||
|
if compresslevel is None:
|
||||||
|
compresslevel = 3
|
||||||
|
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||||
|
else:
|
||||||
|
return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel)
|
||||||
|
else:
|
||||||
|
@patch(zipfile, '_get_compressor')
|
||||||
|
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||||
|
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||||
|
if compresslevel is None:
|
||||||
|
compresslevel = 3
|
||||||
|
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||||
|
else:
|
||||||
|
return patch.originals['_get_compressor'](compress_type)
|
||||||
|
|
||||||
|
|
||||||
|
@patch(zipfile.ZipInfo, 'FileHeader')
|
||||||
|
def zstd_FileHeader(self, zip64=None):
|
||||||
|
if self.compress_type == zipfile.ZIP_ZSTANDARD:
|
||||||
|
self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION)
|
||||||
|
self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION)
|
||||||
|
return patch.originals['FileHeader'](self, zip64=zip64)
|
@ -16,4 +16,5 @@ passlib[bcrypt]==1.7.4
|
|||||||
python-jose==3.3.0
|
python-jose==3.3.0
|
||||||
email-validator==1.2.1
|
email-validator==1.2.1
|
||||||
watchfiles==0.14.1
|
watchfiles==0.14.1
|
||||||
|
zstandard==0.17.0
|
||||||
setuptools==60.6.0 # don't upgrade because of https://github.com/pypa/setuptools/issues/3084
|
setuptools==60.6.0 # don't upgrade because of https://github.com/pypa/setuptools/issues/3084
|
||||||
|
@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
import zipfile
|
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -26,6 +25,7 @@ from httpx import AsyncClient
|
|||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
from tests.utils import asyncio_patch
|
from tests.utils import asyncio_patch
|
||||||
|
|
||||||
|
import gns3server.utils.zipfile_zstd as zipfile_zstd
|
||||||
from gns3server.controller import Controller
|
from gns3server.controller import Controller
|
||||||
from gns3server.controller.project import Project
|
from gns3server.controller.project import Project
|
||||||
|
|
||||||
@ -261,7 +261,7 @@ async def test_export_with_images(app: FastAPI, client: AsyncClient, tmpdir, pro
|
|||||||
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||||
with myzip.open("a") as myfile:
|
with myzip.open("a") as myfile:
|
||||||
content = myfile.read()
|
content = myfile.read()
|
||||||
assert content == b"hello"
|
assert content == b"hello"
|
||||||
@ -304,7 +304,7 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
|
|||||||
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
|
||||||
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||||
with myzip.open("a") as myfile:
|
with myzip.open("a") as myfile:
|
||||||
content = myfile.read()
|
content = myfile.read()
|
||||||
assert content == b"hello"
|
assert content == b"hello"
|
||||||
@ -313,6 +313,67 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
|
|||||||
myzip.getinfo("images/IOS/test.image")
|
myzip.getinfo("images/IOS/test.image")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"compression, compression_level, status_code",
|
||||||
|
(
|
||||||
|
("none", None, status.HTTP_200_OK),
|
||||||
|
("none", 4, status.HTTP_400_BAD_REQUEST),
|
||||||
|
("zip", None, status.HTTP_200_OK),
|
||||||
|
("zip", 1, status.HTTP_200_OK),
|
||||||
|
("zip", 12, status.HTTP_400_BAD_REQUEST),
|
||||||
|
("bzip2", None, status.HTTP_200_OK),
|
||||||
|
("bzip2", 1, status.HTTP_200_OK),
|
||||||
|
("bzip2", 13, status.HTTP_400_BAD_REQUEST),
|
||||||
|
("lzma", None, status.HTTP_200_OK),
|
||||||
|
("lzma", 1, status.HTTP_400_BAD_REQUEST),
|
||||||
|
("zstd", None, status.HTTP_200_OK),
|
||||||
|
("zstd", 12, status.HTTP_200_OK),
|
||||||
|
("zstd", 23, status.HTTP_400_BAD_REQUEST),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async def test_export_compression(
|
||||||
|
app: FastAPI,
|
||||||
|
client: AsyncClient,
|
||||||
|
tmpdir,
|
||||||
|
project: Project,
|
||||||
|
compression: str,
|
||||||
|
compression_level: int,
|
||||||
|
status_code: int
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
project.dump = MagicMock()
|
||||||
|
os.makedirs(project.path, exist_ok=True)
|
||||||
|
|
||||||
|
topology = {
|
||||||
|
"topology": {
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"node_type": "qemu"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
with open(os.path.join(project.path, "test.gns3"), 'w+') as f:
|
||||||
|
json.dump(topology, f)
|
||||||
|
|
||||||
|
params = {"compression": compression}
|
||||||
|
if compression_level:
|
||||||
|
params["compression_level"] = compression_level
|
||||||
|
response = await client.get(app.url_path_for("export_project", project_id=project.id), params=params)
|
||||||
|
assert response.status_code == status_code
|
||||||
|
|
||||||
|
if response.status_code == status.HTTP_200_OK:
|
||||||
|
assert response.headers['CONTENT-TYPE'] == 'application/gns3project'
|
||||||
|
assert response.headers['CONTENT-DISPOSITION'] == 'attachment; filename="{}.gns3project"'.format(project.name)
|
||||||
|
|
||||||
|
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||||
|
with myzip.open("project.gns3") as myfile:
|
||||||
|
myfile.read()
|
||||||
|
|
||||||
|
|
||||||
async def test_get_file(app: FastAPI, client: AsyncClient, project: Project) -> None:
|
async def test_get_file(app: FastAPI, client: AsyncClient, project: Project) -> None:
|
||||||
|
|
||||||
os.makedirs(project.path, exist_ok=True)
|
os.makedirs(project.path, exist_ok=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user