diff --git a/gns3server/endpoints/controller/projects.py b/gns3server/endpoints/controller/projects.py index 4e20aa42..16d18e0f 100644 --- a/gns3server/endpoints/controller/projects.py +++ b/gns3server/endpoints/controller/projects.py @@ -29,12 +29,13 @@ import time import logging log = logging.getLogger() -from fastapi import APIRouter, Depends, Request, Body, HTTPException, status, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Request, Body, Query, HTTPException, status, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse, FileResponse from websockets.exceptions import ConnectionClosed, WebSocketException -from typing import List +from typing import List, Optional from uuid import UUID +from pathlib import Path from gns3server.endpoints.schemas.common import ErrorMessage from gns3server.endpoints import schemas @@ -307,20 +308,16 @@ async def export_project(project: Project = Depends(dep_project), status_code=status.HTTP_201_CREATED, response_model=schemas.Project, responses=responses) -async def import_project(project_id: UUID, request: Request): +async def import_project(project_id: UUID, request: Request, path: Optional[Path] = None, name: Optional[str] = None): """ Import a project from a portable archive. """ controller = Controller.instance() config = Config.instance() - if config.get_section_config("Server").getboolean("local", False) is False: + if not config.get_section_config("Server").getboolean("local", False): raise ControllerForbiddenError("The server is not local") - #FIXME: broken - path = None - name = "test" - # We write the content to a temporary location and after we extract it all. # It could be more optimal to stream this but it is not implemented in Python. try: diff --git a/tests/conftest.py b/tests/conftest.py index 4d71b007..e89628fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -259,6 +259,7 @@ def run_around_tests(monkeypatch, config, port_manager):#port_manager, controlle config.set("Server", "appliances_path", os.path.join(tmppath, 'appliances')) config.set("Server", "ubridge_path", os.path.join(tmppath, 'bin', 'ubridge')) config.set("Server", "auth", False) + config.set("Server", "local", True) # Prevent executions of the VM if we forgot to mock something config.set("VirtualBox", "vboxmanage_path", tmppath) diff --git a/tests/endpoints/controller/test_projects.py b/tests/endpoints/controller/test_projects.py index 195d824b..81946dbe 100644 --- a/tests/endpoints/controller/test_projects.py +++ b/tests/endpoints/controller/test_projects.py @@ -21,10 +21,8 @@ import pytest import zipfile import json -from fastapi.testclient import TestClient from unittest.mock import patch, MagicMock from tests.utils import asyncio_patch -from gns3server.app import app @pytest.fixture @@ -362,22 +360,22 @@ async def test_write_and_get_file_with_leading_slashes_in_filename(controller_ap assert response.content == b"world" -# @pytest.mark.asyncio -# async def test_import(controller_api, tmpdir, controller): -# -# with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: -# myzip.writestr("project.gns3", b'{"project_id": "c6992992-ac72-47dc-833b-54aa334bcd05", "version": "2.0.0", "name": "test"}') -# myzip.writestr("demo", b"hello") -# -# project_id = str(uuid.uuid4()) -# with open(str(tmpdir / "test.zip"), "rb") as f: -# response = await controller_api.post("/projects/{project_id}/import".format(project_id=project_id), body=f.read(), raw=True) -# assert response.status_code == 201 -# -# project = controller.get_project(project_id) -# with open(os.path.join(project.path, "demo")) as f: -# content = f.read() -# assert content == "hello" +@pytest.mark.asyncio +async def test_import(controller_api, tmpdir, controller): + + with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: + myzip.writestr("project.gns3", b'{"project_id": "c6992992-ac72-47dc-833b-54aa334bcd05", "version": "2.0.0", "name": "test"}') + myzip.writestr("demo", b"hello") + + project_id = str(uuid.uuid4()) + with open(str(tmpdir / "test.zip"), "rb") as f: + response = await controller_api.post("/projects/{project_id}/import".format(project_id=project_id), body=f.read(), raw=True) + assert response.status_code == 201 + + project = controller.get_project(project_id) + with open(os.path.join(project.path, "demo")) as f: + content = f.read() + assert content == "hello" @pytest.mark.asyncio