mirror of
https://github.com/GNS3/gns3-server.git
synced 2024-12-22 06:07:51 +00:00
Merge pull request #2070 from GNS3/project-export-zstd
zstandard compression support for project export
This commit is contained in:
commit
466aaf5c13
@ -21,10 +21,10 @@ API routes for projects.
|
||||
import os
|
||||
import asyncio
|
||||
import tempfile
|
||||
import zipfile
|
||||
import aiofiles
|
||||
import time
|
||||
import urllib.parse
|
||||
import gns3server.utils.zipfile_zstd as zipfile
|
||||
|
||||
import logging
|
||||
|
||||
@ -41,7 +41,7 @@ from pathlib import Path
|
||||
from gns3server import schemas
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.controller.project import Project
|
||||
from gns3server.controller.controller_error import ControllerError, ControllerForbiddenError
|
||||
from gns3server.controller.controller_error import ControllerError, ControllerBadRequestError
|
||||
from gns3server.controller.import_project import import_project as import_controller_project
|
||||
from gns3server.controller.export_project import export_project as export_controller_project
|
||||
from gns3server.utils.asyncio import aiozipstream
|
||||
@ -285,7 +285,8 @@ async def export_project(
|
||||
include_snapshots: bool = False,
|
||||
include_images: bool = False,
|
||||
reset_mac_addresses: bool = False,
|
||||
compression: str = "zip",
|
||||
compression: schemas.ProjectCompression = "zstd",
|
||||
compression_level: int = None,
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Export a project as a portable archive.
|
||||
@ -294,12 +295,23 @@ async def export_project(
|
||||
compression_query = compression.lower()
|
||||
if compression_query == "zip":
|
||||
compression = zipfile.ZIP_DEFLATED
|
||||
if compression_level is not None and (compression_level < 0 or compression_level > 9):
|
||||
raise ControllerBadRequestError("Compression level must be between 0 and 9 for ZIP compression")
|
||||
elif compression_query == "none":
|
||||
compression = zipfile.ZIP_STORED
|
||||
elif compression_query == "bzip2":
|
||||
compression = zipfile.ZIP_BZIP2
|
||||
if compression_level is not None and (compression_level < 1 or compression_level > 9):
|
||||
raise ControllerBadRequestError("Compression level must be between 1 and 9 for BZIP2 compression")
|
||||
elif compression_query == "lzma":
|
||||
compression = zipfile.ZIP_LZMA
|
||||
elif compression_query == "zstd":
|
||||
compression = zipfile.ZIP_ZSTANDARD
|
||||
if compression_level is not None and (compression_level < 1 or compression_level > 22):
|
||||
raise ControllerBadRequestError("Compression level must be between 1 and 22 for Zstandard compression")
|
||||
|
||||
if compression_level is not None and compression_query in ("none", "lzma"):
|
||||
raise ControllerBadRequestError(f"Compression level is not supported for '{compression_query}' compression method")
|
||||
|
||||
try:
|
||||
begin = time.time()
|
||||
@ -307,8 +319,10 @@ async def export_project(
|
||||
working_dir = os.path.abspath(os.path.join(project.path, os.pardir))
|
||||
|
||||
async def streamer():
|
||||
log.info(f"Exporting project '{project.name}' with '{compression_query}' compression "
|
||||
f"(level {compression_level})")
|
||||
with tempfile.TemporaryDirectory(dir=working_dir) as tmpdir:
|
||||
with aiozipstream.ZipFile(compression=compression) as zstream:
|
||||
with aiozipstream.ZipFile(compression=compression, compresslevel=compression_level) as zstream:
|
||||
await export_controller_project(
|
||||
zstream,
|
||||
project,
|
||||
|
@ -166,12 +166,14 @@ async def sqlalchemry_error_handler(request: Request, exc: SQLAlchemyError):
|
||||
content={"message": "Database error detected, please check logs to find details"},
|
||||
)
|
||||
|
||||
# FIXME: do not use this middleware since it creates issue when using StreamingResponse
|
||||
# see https://starlette-context.readthedocs.io/en/latest/middleware.html#why-are-there-two-middlewares-that-do-the-same-thing
|
||||
|
||||
@app.middleware("http")
|
||||
async def add_extra_headers(request: Request, call_next):
|
||||
start_time = time.time()
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
response.headers["X-GNS3-Server-Version"] = f"{__version__}"
|
||||
return response
|
||||
# @app.middleware("http")
|
||||
# async def add_extra_headers(request: Request, call_next):
|
||||
# start_time = time.time()
|
||||
# response = await call_next(request)
|
||||
# process_time = time.time() - start_time
|
||||
# response.headers["X-Process-Time"] = str(process_time)
|
||||
# response.headers["X-GNS3-Server-Version"] = f"{__version__}"
|
||||
# return response
|
||||
|
@ -16,7 +16,6 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
import aiofiles
|
||||
|
@ -20,10 +20,10 @@ import sys
|
||||
import json
|
||||
import uuid
|
||||
import shutil
|
||||
import zipfile
|
||||
import aiofiles
|
||||
import itertools
|
||||
import tempfile
|
||||
import gns3server.utils.zipfile_zstd as zipfile_zstd
|
||||
|
||||
from .controller_error import ControllerError
|
||||
from .topology import load_topology
|
||||
@ -60,9 +60,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
|
||||
raise ControllerError("The destination path should not contain .gns3")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(stream) as zip_file:
|
||||
with zipfile_zstd.ZipFile(stream) as zip_file:
|
||||
project_file = zip_file.read("project.gns3").decode()
|
||||
except zipfile.BadZipFile:
|
||||
except zipfile_zstd.BadZipFile:
|
||||
raise ControllerError("Cannot import project, not a GNS3 project (invalid zip)")
|
||||
except KeyError:
|
||||
raise ControllerError("Cannot import project, project.gns3 file could not be found")
|
||||
@ -92,9 +92,9 @@ async def import_project(controller, project_id, stream, location=None, name=Non
|
||||
raise ControllerError("The project name contain non supported or invalid characters")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(stream) as zip_file:
|
||||
with zipfile_zstd.ZipFile(stream) as zip_file:
|
||||
await wait_run_in_executor(zip_file.extractall, path)
|
||||
except zipfile.BadZipFile:
|
||||
except zipfile_zstd.BadZipFile:
|
||||
raise ControllerError("Cannot extract files from GNS3 project (invalid zip)")
|
||||
|
||||
topology = load_topology(os.path.join(path, "project.gns3"))
|
||||
@ -264,11 +264,11 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
|
||||
# extract everything to a temporary directory
|
||||
try:
|
||||
with open(snapshot_path, "rb") as f:
|
||||
with zipfile.ZipFile(f) as zip_file:
|
||||
with zipfile_zstd.ZipFile(f) as zip_file:
|
||||
await wait_run_in_executor(zip_file.extractall, tmpdir)
|
||||
except OSError as e:
|
||||
raise ControllerError(f"Cannot open snapshot '{os.path.basename(snapshot)}': {e}")
|
||||
except zipfile.BadZipFile:
|
||||
except zipfile_zstd.BadZipFile:
|
||||
raise ControllerError(
|
||||
f"Cannot extract files from snapshot '{os.path.basename(snapshot)}': not a GNS3 project (invalid zip)"
|
||||
)
|
||||
@ -294,7 +294,7 @@ async def _import_snapshots(snapshots_path, project_name, project_id):
|
||||
|
||||
# write everything back to the original snapshot file
|
||||
try:
|
||||
with aiozipstream.ZipFile(compression=zipfile.ZIP_STORED) as zstream:
|
||||
with aiozipstream.ZipFile(compression=zipfile_zstd.ZIP_STORED) as zstream:
|
||||
for root, dirs, files in os.walk(tmpdir, topdown=True, followlinks=False):
|
||||
for file in files:
|
||||
path = os.path.join(root, file)
|
||||
|
@ -28,7 +28,7 @@ from .controller.appliances import ApplianceVersion, Appliance
|
||||
from .controller.drawings import Drawing
|
||||
from .controller.gns3vm import GNS3VM
|
||||
from .controller.nodes import NodeCreate, NodeUpdate, NodeDuplicate, NodeCapture, Node
|
||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile
|
||||
from .controller.projects import ProjectCreate, ProjectUpdate, ProjectDuplicate, Project, ProjectFile, ProjectCompression
|
||||
from .controller.users import UserCreate, UserUpdate, LoggedInUserUpdate, User, Credentials, UserGroupCreate, UserGroupUpdate, UserGroup
|
||||
from .controller.rbac import RoleCreate, RoleUpdate, Role, PermissionCreate, PermissionUpdate, Permission
|
||||
from .controller.tokens import Token
|
||||
|
@ -102,3 +102,15 @@ class ProjectFile(BaseModel):
|
||||
|
||||
path: str = Field(..., description="File path")
|
||||
md5sum: str = Field(..., description="File checksum")
|
||||
|
||||
|
||||
class ProjectCompression(str, Enum):
|
||||
"""
|
||||
Supported project compression.
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
zip = "zip"
|
||||
bzip2 = "bzip2"
|
||||
lzma = "lzma"
|
||||
zstd = "zstd"
|
||||
|
@ -43,26 +43,38 @@ from zipfile import (
|
||||
stringEndArchive64Locator,
|
||||
)
|
||||
|
||||
|
||||
ZIP_ZSTANDARD = 93 # zstandard is supported by WinZIP v24 and later, PowerArchiver 2021 and 7-Zip-zstd
|
||||
ZSTANDARD_VERSION = 20
|
||||
stringDataDescriptor = b"PK\x07\x08" # magic number for data descriptor
|
||||
|
||||
|
||||
def _get_compressor(compress_type):
|
||||
def _get_compressor(compress_type, compresslevel=None):
|
||||
"""
|
||||
Return the compressor.
|
||||
"""
|
||||
|
||||
if compress_type == zipfile.ZIP_DEFLATED:
|
||||
from zipfile import zlib
|
||||
|
||||
if compresslevel is not None:
|
||||
return zlib.compressobj(compresslevel, zlib.DEFLATED, -15)
|
||||
return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
|
||||
elif compress_type == zipfile.ZIP_BZIP2:
|
||||
from zipfile import bz2
|
||||
|
||||
if compresslevel is not None:
|
||||
return bz2.BZ2Compressor(compresslevel)
|
||||
return bz2.BZ2Compressor()
|
||||
# compresslevel is ignored for ZIP_LZMA
|
||||
elif compress_type == zipfile.ZIP_LZMA:
|
||||
from zipfile import LZMACompressor
|
||||
|
||||
return LZMACompressor()
|
||||
elif compress_type == ZIP_ZSTANDARD:
|
||||
import zstandard as zstd
|
||||
if compresslevel is not None:
|
||||
#params = zstd.ZstdCompressionParameters.from_level(compresslevel, threads=-1, enable_ldm=True, window_log=31)
|
||||
#return zstd.ZstdCompressor(compression_params=params).compressobj()
|
||||
return zstd.ZstdCompressor(level=compresslevel).compressobj()
|
||||
return zstd.ZstdCompressor().compressobj()
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -129,7 +141,15 @@ class ZipInfo(zipfile.ZipInfo):
|
||||
|
||||
|
||||
class ZipFile(zipfile.ZipFile):
|
||||
def __init__(self, fileobj=None, mode="w", compression=zipfile.ZIP_STORED, allowZip64=True, chunksize=32768):
|
||||
def __init__(
|
||||
self,
|
||||
fileobj=None,
|
||||
mode="w",
|
||||
compression=zipfile.ZIP_STORED,
|
||||
allowZip64=True,
|
||||
compresslevel=None,
|
||||
chunksize=32768
|
||||
):
|
||||
"""Open the ZIP file with mode write "w"."""
|
||||
|
||||
if mode not in ("w",):
|
||||
@ -138,7 +158,13 @@ class ZipFile(zipfile.ZipFile):
|
||||
fileobj = PointerIO()
|
||||
|
||||
self._comment = b""
|
||||
zipfile.ZipFile.__init__(self, fileobj, mode=mode, compression=compression, allowZip64=allowZip64)
|
||||
zipfile.ZipFile.__init__(
|
||||
self, fileobj,
|
||||
mode=mode,
|
||||
compression=compression,
|
||||
compresslevel=compresslevel,
|
||||
allowZip64=allowZip64
|
||||
)
|
||||
self._chunksize = chunksize
|
||||
self.paths_to_write = []
|
||||
|
||||
@ -195,23 +221,33 @@ class ZipFile(zipfile.ZipFile):
|
||||
for chunk in self._close():
|
||||
yield chunk
|
||||
|
||||
def write(self, filename, arcname=None, compress_type=None):
|
||||
def write(self, filename, arcname=None, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Write a file to the archive under the name `arcname`.
|
||||
"""
|
||||
|
||||
kwargs = {"filename": filename, "arcname": arcname, "compress_type": compress_type}
|
||||
kwargs = {
|
||||
"filename": filename,
|
||||
"arcname": arcname,
|
||||
"compress_type": compress_type,
|
||||
"compresslevel": compresslevel
|
||||
}
|
||||
self.paths_to_write.append(kwargs)
|
||||
|
||||
def write_iter(self, arcname, iterable, compress_type=None):
|
||||
def write_iter(self, arcname, iterable, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Write the bytes iterable `iterable` to the archive under the name `arcname`.
|
||||
"""
|
||||
|
||||
kwargs = {"arcname": arcname, "iterable": iterable, "compress_type": compress_type}
|
||||
kwargs = {
|
||||
"arcname": arcname,
|
||||
"iterable": iterable,
|
||||
"compress_type": compress_type,
|
||||
"compresslevel": compresslevel
|
||||
}
|
||||
self.paths_to_write.append(kwargs)
|
||||
|
||||
def writestr(self, arcname, data, compress_type=None):
|
||||
def writestr(self, arcname, data, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Writes a str into ZipFile by wrapping data as a generator
|
||||
"""
|
||||
@ -219,9 +255,9 @@ class ZipFile(zipfile.ZipFile):
|
||||
def _iterable():
|
||||
yield data
|
||||
|
||||
return self.write_iter(arcname, _iterable(), compress_type=compress_type)
|
||||
return self.write_iter(arcname, _iterable(), compress_type=compress_type, compresslevel=compresslevel)
|
||||
|
||||
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None):
|
||||
async def _write(self, filename=None, iterable=None, arcname=None, compress_type=None, compresslevel=None):
|
||||
"""
|
||||
Put the bytes from filename into the archive under the name `arcname`.
|
||||
"""
|
||||
@ -256,6 +292,11 @@ class ZipFile(zipfile.ZipFile):
|
||||
else:
|
||||
zinfo.compress_type = compress_type
|
||||
|
||||
if compresslevel is None:
|
||||
zinfo._compresslevel = self.compresslevel
|
||||
else:
|
||||
zinfo._compresslevel = compresslevel
|
||||
|
||||
if st:
|
||||
zinfo.file_size = st[6]
|
||||
else:
|
||||
@ -279,7 +320,7 @@ class ZipFile(zipfile.ZipFile):
|
||||
yield self.fp.write(zinfo.FileHeader(False))
|
||||
return
|
||||
|
||||
cmpr = _get_compressor(zinfo.compress_type)
|
||||
cmpr = _get_compressor(zinfo.compress_type, zinfo._compresslevel)
|
||||
|
||||
# Must overwrite CRC and sizes with correct data later
|
||||
zinfo.CRC = CRC = 0
|
||||
@ -369,6 +410,8 @@ class ZipFile(zipfile.ZipFile):
|
||||
min_version = max(zipfile.BZIP2_VERSION, min_version)
|
||||
elif zinfo.compress_type == zipfile.ZIP_LZMA:
|
||||
min_version = max(zipfile.LZMA_VERSION, min_version)
|
||||
elif zinfo.compress_type == ZIP_ZSTANDARD:
|
||||
min_version = max(ZSTANDARD_VERSION, min_version)
|
||||
|
||||
extract_version = max(min_version, zinfo.extract_version)
|
||||
create_version = max(min_version, zinfo.create_version)
|
||||
|
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
10
gns3server/utils/zipfile_zstd/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
# NOTE: this patches the standard zipfile module
|
||||
from . import _zipfile
|
||||
|
||||
from zipfile import *
|
||||
from zipfile import (
|
||||
ZIP_ZSTANDARD,
|
||||
ZSTANDARD_VERSION,
|
||||
)
|
||||
|
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
20
gns3server/utils/zipfile_zstd/_patcher.py
Normal file
@ -0,0 +1,20 @@
|
||||
import functools
|
||||
|
||||
|
||||
class patch:
|
||||
|
||||
originals = {}
|
||||
|
||||
def __init__(self, host, name):
|
||||
self.host = host
|
||||
self.name = name
|
||||
|
||||
def __call__(self, func):
|
||||
original = getattr(self.host, self.name)
|
||||
self.originals[self.name] = original
|
||||
|
||||
functools.update_wrapper(func, original)
|
||||
setattr(self.host, self.name, func)
|
||||
|
||||
return func
|
||||
|
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
64
gns3server/utils/zipfile_zstd/_zipfile.py
Normal file
@ -0,0 +1,64 @@
|
||||
import zipfile
|
||||
import zstandard as zstd
|
||||
import inspect
|
||||
|
||||
from ._patcher import patch
|
||||
|
||||
|
||||
zipfile.ZIP_ZSTANDARD = 93
|
||||
zipfile.compressor_names[zipfile.ZIP_ZSTANDARD] = 'zstandard'
|
||||
zipfile.ZSTANDARD_VERSION = 20
|
||||
|
||||
|
||||
@patch(zipfile, '_check_compression')
|
||||
def zstd_check_compression(compression):
|
||||
if compression == zipfile.ZIP_ZSTANDARD:
|
||||
pass
|
||||
else:
|
||||
patch.originals['_check_compression'](compression)
|
||||
|
||||
|
||||
class ZstdDecompressObjWrapper:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr == 'eof':
|
||||
return False
|
||||
return getattr(self.o, attr)
|
||||
|
||||
|
||||
@patch(zipfile, '_get_decompressor')
|
||||
def zstd_get_decompressor(compress_type):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
return ZstdDecompressObjWrapper(zstd.ZstdDecompressor(max_window_size=2147483648).decompressobj())
|
||||
else:
|
||||
return patch.originals['_get_decompressor'](compress_type)
|
||||
|
||||
|
||||
if 'compresslevel' in inspect.signature(zipfile._get_compressor).parameters:
|
||||
@patch(zipfile, '_get_compressor')
|
||||
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
if compresslevel is None:
|
||||
compresslevel = 3
|
||||
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||
else:
|
||||
return patch.originals['_get_compressor'](compress_type, compresslevel=compresslevel)
|
||||
else:
|
||||
@patch(zipfile, '_get_compressor')
|
||||
def zstd_get_compressor(compress_type, compresslevel=None):
|
||||
if compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
if compresslevel is None:
|
||||
compresslevel = 3
|
||||
return zstd.ZstdCompressor(level=compresslevel, threads=12).compressobj()
|
||||
else:
|
||||
return patch.originals['_get_compressor'](compress_type)
|
||||
|
||||
|
||||
@patch(zipfile.ZipInfo, 'FileHeader')
|
||||
def zstd_FileHeader(self, zip64=None):
|
||||
if self.compress_type == zipfile.ZIP_ZSTANDARD:
|
||||
self.create_version = max(self.create_version, zipfile.ZSTANDARD_VERSION)
|
||||
self.extract_version = max(self.extract_version, zipfile.ZSTANDARD_VERSION)
|
||||
return patch.originals['FileHeader'](self, zip64=zip64)
|
@ -16,4 +16,5 @@ passlib[bcrypt]==1.7.4
|
||||
python-jose==3.3.0
|
||||
email-validator==1.2.1
|
||||
watchfiles==0.14.1
|
||||
zstandard==0.17.0
|
||||
setuptools==60.6.0 # don't upgrade because of https://github.com/pypa/setuptools/issues/3084
|
||||
|
@ -17,7 +17,6 @@
|
||||
|
||||
import uuid
|
||||
import os
|
||||
import zipfile
|
||||
import json
|
||||
import pytest
|
||||
|
||||
@ -26,6 +25,7 @@ from httpx import AsyncClient
|
||||
from unittest.mock import patch, MagicMock
|
||||
from tests.utils import asyncio_patch
|
||||
|
||||
import gns3server.utils.zipfile_zstd as zipfile_zstd
|
||||
from gns3server.controller import Controller
|
||||
from gns3server.controller.project import Project
|
||||
|
||||
@ -261,7 +261,7 @@ async def test_export_with_images(app: FastAPI, client: AsyncClient, tmpdir, pro
|
||||
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||
f.write(response.content)
|
||||
|
||||
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||
with myzip.open("a") as myfile:
|
||||
content = myfile.read()
|
||||
assert content == b"hello"
|
||||
@ -304,7 +304,7 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
|
||||
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||
f.write(response.content)
|
||||
|
||||
with zipfile.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||
with myzip.open("a") as myfile:
|
||||
content = myfile.read()
|
||||
assert content == b"hello"
|
||||
@ -313,6 +313,67 @@ async def test_export_without_images(app: FastAPI, client: AsyncClient, tmpdir,
|
||||
myzip.getinfo("images/IOS/test.image")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"compression, compression_level, status_code",
|
||||
(
|
||||
("none", None, status.HTTP_200_OK),
|
||||
("none", 4, status.HTTP_400_BAD_REQUEST),
|
||||
("zip", None, status.HTTP_200_OK),
|
||||
("zip", 1, status.HTTP_200_OK),
|
||||
("zip", 12, status.HTTP_400_BAD_REQUEST),
|
||||
("bzip2", None, status.HTTP_200_OK),
|
||||
("bzip2", 1, status.HTTP_200_OK),
|
||||
("bzip2", 13, status.HTTP_400_BAD_REQUEST),
|
||||
("lzma", None, status.HTTP_200_OK),
|
||||
("lzma", 1, status.HTTP_400_BAD_REQUEST),
|
||||
("zstd", None, status.HTTP_200_OK),
|
||||
("zstd", 12, status.HTTP_200_OK),
|
||||
("zstd", 23, status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
)
|
||||
async def test_export_compression(
|
||||
app: FastAPI,
|
||||
client: AsyncClient,
|
||||
tmpdir,
|
||||
project: Project,
|
||||
compression: str,
|
||||
compression_level: int,
|
||||
status_code: int
|
||||
) -> None:
|
||||
|
||||
project.dump = MagicMock()
|
||||
os.makedirs(project.path, exist_ok=True)
|
||||
|
||||
topology = {
|
||||
"topology": {
|
||||
"nodes": [
|
||||
{
|
||||
"node_type": "qemu"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
with open(os.path.join(project.path, "test.gns3"), 'w+') as f:
|
||||
json.dump(topology, f)
|
||||
|
||||
params = {"compression": compression}
|
||||
if compression_level:
|
||||
params["compression_level"] = compression_level
|
||||
response = await client.get(app.url_path_for("export_project", project_id=project.id), params=params)
|
||||
assert response.status_code == status_code
|
||||
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
assert response.headers['CONTENT-TYPE'] == 'application/gns3project'
|
||||
assert response.headers['CONTENT-DISPOSITION'] == 'attachment; filename="{}.gns3project"'.format(project.name)
|
||||
|
||||
with open(str(tmpdir / 'project.zip'), 'wb+') as f:
|
||||
f.write(response.content)
|
||||
|
||||
with zipfile_zstd.ZipFile(str(tmpdir / 'project.zip')) as myzip:
|
||||
with myzip.open("project.gns3") as myfile:
|
||||
myfile.read()
|
||||
|
||||
|
||||
async def test_get_file(app: FastAPI, client: AsyncClient, project: Project) -> None:
|
||||
|
||||
os.makedirs(project.path, exist_ok=True)
|
||||
|
Loading…
Reference in New Issue
Block a user