mirror of
https://github.com/GNS3/gns3-server.git
synced 2025-06-16 06:18:19 +00:00
Use aiofiles where relevant.
This commit is contained in:
@ -37,6 +37,8 @@ from gns3server.schemas.project import (
|
||||
import logging
|
||||
log = logging.getLogger()
|
||||
|
||||
CHUNK_SIZE = 1024 * 8 # 8KB
|
||||
|
||||
|
||||
class ProjectHandler:
|
||||
|
||||
@ -248,64 +250,7 @@ class ProjectHandler:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
path = os.path.join(project.path, path)
|
||||
|
||||
response.content_type = "application/octet-stream"
|
||||
response.set_status(200)
|
||||
response.enable_chunked_encoding()
|
||||
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while True:
|
||||
data = f.read(4096)
|
||||
if not data:
|
||||
break
|
||||
await response.write(data)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
@Route.get(
|
||||
r"/projects/{project_id}/stream/{path:.+}",
|
||||
description="Stream a file from a project",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
status_codes={
|
||||
200: "File returned",
|
||||
403: "Permission denied",
|
||||
404: "The file doesn't exist"
|
||||
})
|
||||
async def stream_file(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project = pm.get_project(request.match_info["project_id"])
|
||||
path = request.match_info["path"]
|
||||
path = os.path.normpath(path)
|
||||
|
||||
# Raise an error if user try to escape
|
||||
if path[0] == ".":
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
path = os.path.join(project.path, path)
|
||||
|
||||
response.content_type = "application/octet-stream"
|
||||
response.set_status(200)
|
||||
response.enable_chunked_encoding()
|
||||
|
||||
# FIXME: file streaming is never stopped
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
await response.prepare(request)
|
||||
while True:
|
||||
data = f.read(4096)
|
||||
if not data:
|
||||
await asyncio.sleep(0.1)
|
||||
await response.write(data)
|
||||
except FileNotFoundError:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
await response.stream_file(path)
|
||||
|
||||
@Route.post(
|
||||
r"/projects/{project_id}/files/{path:.+}",
|
||||
@ -338,7 +283,7 @@ class ProjectHandler:
|
||||
with open(path, 'wb+') as f:
|
||||
while True:
|
||||
try:
|
||||
chunk = await request.content.read(1024)
|
||||
chunk = await request.content.read(CHUNK_SIZE)
|
||||
except asyncio.TimeoutError:
|
||||
raise aiohttp.web.HTTPRequestTimeout(text="Timeout when writing to file '{}'".format(path))
|
||||
if not chunk:
|
||||
@ -349,64 +294,3 @@ class ProjectHandler:
|
||||
raise aiohttp.web.HTTPNotFound()
|
||||
except PermissionError:
|
||||
raise aiohttp.web.HTTPForbidden()
|
||||
|
||||
@Route.get(
|
||||
r"/projects/{project_id}/export",
|
||||
description="Export a project as a portable archive",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
raw=True,
|
||||
status_codes={
|
||||
200: "File returned",
|
||||
404: "The project doesn't exist"
|
||||
})
|
||||
async def export_project(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project = pm.get_project(request.match_info["project_id"])
|
||||
response.content_type = 'application/gns3project'
|
||||
response.headers['CONTENT-DISPOSITION'] = 'attachment; filename="{}.gns3project"'.format(project.name)
|
||||
response.enable_chunked_encoding()
|
||||
await response.prepare(request)
|
||||
|
||||
include_images = bool(int(request.json.get("include_images", "0")))
|
||||
for data in project.export(include_images=include_images):
|
||||
await response.write(data)
|
||||
|
||||
#await response.write_eof() #FIXME: shound't be needed anymore
|
||||
|
||||
@Route.post(
|
||||
r"/projects/{project_id}/import",
|
||||
description="Import a project from a portable archive",
|
||||
parameters={
|
||||
"project_id": "Project UUID",
|
||||
},
|
||||
raw=True,
|
||||
output=PROJECT_OBJECT_SCHEMA,
|
||||
status_codes={
|
||||
200: "Project imported",
|
||||
403: "Forbidden to import project"
|
||||
})
|
||||
async def import_project(request, response):
|
||||
|
||||
pm = ProjectManager.instance()
|
||||
project_id = request.match_info["project_id"]
|
||||
project = pm.create_project(project_id=project_id)
|
||||
|
||||
# We write the content to a temporary location and after we extract it all.
|
||||
# It could be more optimal to stream this but it is not implemented in Python.
|
||||
# Spooled means the file is temporary kept in memory until max_size is reached
|
||||
try:
|
||||
with tempfile.SpooledTemporaryFile(max_size=10000) as temp:
|
||||
while True:
|
||||
chunk = await request.content.read(1024)
|
||||
if not chunk:
|
||||
break
|
||||
temp.write(chunk)
|
||||
project.import_zip(temp, gns3vm=bool(int(request.GET.get("gns3vm", "1"))))
|
||||
except OSError as e:
|
||||
raise aiohttp.web.HTTPInternalServerError(text="Could not import the project: {}".format(e))
|
||||
|
||||
response.json(project)
|
||||
response.set_status(201)
|
||||
|
Reference in New Issue
Block a user