From d6f63d3b7d8373edebcbca68d03eac8d288c4868 Mon Sep 17 00:00:00 2001 From: Julien Duponchelle Date: Thu, 28 Jul 2016 14:59:44 +0200 Subject: [PATCH] Fix Exporting portable projects with QEMU includes base images even when selecting no. Fix https://github.com/GNS3/gns3-gui/issues/1409 --- gns3server/handlers/api/project_handler.py | 3 +- gns3server/web/route.py | 34 +++++++++++++++------- tests/handlers/api/test_project.py | 15 +++++++++- 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/gns3server/handlers/api/project_handler.py b/gns3server/handlers/api/project_handler.py index b56e703f..2866223e 100644 --- a/gns3server/handlers/api/project_handler.py +++ b/gns3server/handlers/api/project_handler.py @@ -368,7 +368,8 @@ class ProjectHandler: response.content_length = None response.start(request) - for data in project.export(include_images=bool(request.GET.get("include_images", "0"))): + include_images = bool(int(request.json.get("include_images", "0"))) + for data in project.export(include_images=include_images): response.write(data) yield from response.drain() diff --git a/gns3server/web/route.py b/gns3server/web/route.py index c588a0a0..ef4c0612 100644 --- a/gns3server/web/route.py +++ b/gns3server/web/route.py @@ -17,11 +17,13 @@ import sys import json -import jsonschema +import urllib import asyncio import aiohttp import logging import traceback +import jsonschema + log = logging.getLogger(__name__) @@ -33,10 +35,11 @@ from ..config import Config @asyncio.coroutine -def parse_request(request, input_schema): +def parse_request(request, input_schema, raw): """Parse body of request and raise HTTP errors in case of problems""" + content_length = request.content_length - if content_length is not None and content_length > 0: + if content_length is not None and content_length > 0 and not raw: body = yield from request.read() try: request.json = json.loads(body.decode('utf-8')) @@ -45,13 +48,21 @@ def parse_request(request, input_schema): raise aiohttp.web.HTTPBadRequest(text="Invalid JSON {}".format(e)) else: request.json = {} - try: - jsonschema.validate(request.json, input_schema) - except jsonschema.ValidationError as e: - log.error("Invalid input query. JSON schema error: {}".format(e.message)) - raise aiohttp.web.HTTPBadRequest(text="Invalid JSON: {} in schema: {}".format( - e.message, - json.dumps(e.schema))) + + # Parse the query string + if len(request.query_string) > 0: + for (k, v) in urllib.parse.parse_qs(request.query_string).items(): + request.json[k] = v[0] + + if input_schema: + try: + jsonschema.validate(request.json, input_schema) + except jsonschema.ValidationError as e: + log.error("Invalid input query. JSON schema error: {}".format(e.message)) + raise aiohttp.web.HTTPBadRequest(text="Invalid JSON: {} in schema: {}".format( + e.message, + json.dumps(e.schema))) + return request @@ -161,12 +172,13 @@ class Route(object): if api_version is None or raw is True: response = Response(request=request, route=route, output_schema=output_schema) + request = yield from parse_request(request, None, raw) yield from func(request, response) return response # API call try: - request = yield from parse_request(request, input_schema) + request = yield from parse_request(request, input_schema, raw) record_file = server_config.get("record") if record_file: try: diff --git a/tests/handlers/api/test_project.py b/tests/handlers/api/test_project.py index 90a4a554..a467a9bf 100644 --- a/tests/handlers/api/test_project.py +++ b/tests/handlers/api/test_project.py @@ -25,7 +25,7 @@ import asyncio import aiohttp import zipfile -from unittest.mock import patch +from unittest.mock import patch, MagicMock from tests.utils import asyncio_patch from gns3server.handlers.api.project_handler import ProjectHandler @@ -306,6 +306,19 @@ def test_export(server, tmpdir, loop, project): assert content == b"hello" +def test_export_include_image(server, tmpdir, loop, project): + + project.export = MagicMock() + response = server.get("/projects/{project_id}/export".format(project_id=project.id), raw=True) + project.export.assert_called_with(include_images=False) + + response = server.get("/projects/{project_id}/export?include_images=0".format(project_id=project.id), raw=True) + project.export.assert_called_with(include_images=False) + + response = server.get("/projects/{project_id}/export?include_images=1".format(project_id=project.id), raw=True) + project.export.assert_called_with(include_images=True) + + def test_import(server, tmpdir, loop, project): with zipfile.ZipFile(str(tmpdir / "test.zip"), 'w') as myzip: