Add HTTP client to reuse the aiohttp session where needed.

Remove unnecessary aiohttp exceptions.
This commit is contained in:
grossmj 2020-10-22 16:19:44 +10:30
parent 36c8920cd1
commit a92c47b310
17 changed files with 221 additions and 151 deletions

View File

@ -41,6 +41,7 @@ from gns3server.controller.controller_error import (
from gns3server.endpoints import controller from gns3server.endpoints import controller
from gns3server.endpoints import index from gns3server.endpoints import index
from gns3server.endpoints.compute import compute_api from gns3server.endpoints.compute import compute_api
from gns3server.utils.http_client import HTTPClient
from gns3server.version import __version__ from gns3server.version import __version__
import logging import logging
@ -76,6 +77,7 @@ app.mount("/v2/compute", compute_api)
@app.exception_handler(ControllerError) @app.exception_handler(ControllerError)
async def controller_error_handler(request: Request, exc: ControllerError): async def controller_error_handler(request: Request, exc: ControllerError):
log.error(f"Controller error: {exc}")
return JSONResponse( return JSONResponse(
status_code=409, status_code=409,
content={"message": str(exc)}, content={"message": str(exc)},
@ -84,6 +86,7 @@ async def controller_error_handler(request: Request, exc: ControllerError):
@app.exception_handler(ControllerTimeoutError) @app.exception_handler(ControllerTimeoutError)
async def controller_timeout_error_handler(request: Request, exc: ControllerTimeoutError): async def controller_timeout_error_handler(request: Request, exc: ControllerTimeoutError):
log.error(f"Controller timeout error: {exc}")
return JSONResponse( return JSONResponse(
status_code=408, status_code=408,
content={"message": str(exc)}, content={"message": str(exc)},
@ -92,6 +95,7 @@ async def controller_timeout_error_handler(request: Request, exc: ControllerTime
@app.exception_handler(ControllerUnauthorizedError) @app.exception_handler(ControllerUnauthorizedError)
async def controller_unauthorized_error_handler(request: Request, exc: ControllerUnauthorizedError): async def controller_unauthorized_error_handler(request: Request, exc: ControllerUnauthorizedError):
log.error(f"Controller unauthorized error: {exc}")
return JSONResponse( return JSONResponse(
status_code=401, status_code=401,
content={"message": str(exc)}, content={"message": str(exc)},
@ -100,6 +104,7 @@ async def controller_unauthorized_error_handler(request: Request, exc: Controlle
@app.exception_handler(ControllerForbiddenError) @app.exception_handler(ControllerForbiddenError)
async def controller_forbidden_error_handler(request: Request, exc: ControllerForbiddenError): async def controller_forbidden_error_handler(request: Request, exc: ControllerForbiddenError):
log.error(f"Controller forbidden error: {exc}")
return JSONResponse( return JSONResponse(
status_code=403, status_code=403,
content={"message": str(exc)}, content={"message": str(exc)},
@ -108,6 +113,7 @@ async def controller_forbidden_error_handler(request: Request, exc: ControllerFo
@app.exception_handler(ControllerNotFoundError) @app.exception_handler(ControllerNotFoundError)
async def controller_not_found_error_handler(request: Request, exc: ControllerNotFoundError): async def controller_not_found_error_handler(request: Request, exc: ControllerNotFoundError):
log.error(f"Controller not found error: {exc}")
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content={"message": str(exc)}, content={"message": str(exc)},
@ -164,13 +170,7 @@ async def startup_event():
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
# close websocket connections await HTTPClient.close_session()
# websocket_connections = set(self._app['websockets'])
# if websocket_connections:
# log.info("Closing {} websocket connections...".format(len(websocket_connections)))
# for ws in websocket_connections:
# await ws.close(code=aiohttp.WSCloseCode.GOING_AWAY, message='Server shutdown')
await Controller.instance().stop() await Controller.instance().stop()
for module in MODULES: for module in MODULES:

View File

@ -60,9 +60,8 @@ class Docker(BaseManager):
if not self._connected: if not self._connected:
try: try:
self._connected = True self._connected = True
connector = self.connector()
version = await self.query("GET", "version") version = await self.query("GET", "version")
except (aiohttp.ClientOSError, FileNotFoundError): except (aiohttp.ClientError, FileNotFoundError):
self._connected = False self._connected = False
raise DockerError("Can't connect to docker daemon") raise DockerError("Can't connect to docker daemon")
@ -70,8 +69,8 @@ class Docker(BaseManager):
if docker_version < parse_version(DOCKER_MINIMUM_API_VERSION): if docker_version < parse_version(DOCKER_MINIMUM_API_VERSION):
raise DockerError( raise DockerError(
"Docker version is {}. GNS3 requires a minimum version of {}".format( "Docker version is {}. GNS3 requires a minimum version of {}".format(version["Version"],
version["Version"], DOCKER_MINIMUM_VERSION)) DOCKER_MINIMUM_VERSION))
preferred_api_version = parse_version(DOCKER_PREFERRED_API_VERSION) preferred_api_version = parse_version(DOCKER_PREFERRED_API_VERSION)
if docker_version >= preferred_api_version: if docker_version >= preferred_api_version:
@ -84,7 +83,7 @@ class Docker(BaseManager):
raise DockerError("Docker is supported only on Linux") raise DockerError("Docker is supported only on Linux")
try: try:
self._connector = aiohttp.connector.UnixConnector(self._server_url, limit=None) self._connector = aiohttp.connector.UnixConnector(self._server_url, limit=None)
except (aiohttp.ClientOSError, FileNotFoundError): except (aiohttp.ClientError, FileNotFoundError):
raise DockerError("Can't connect to docker daemon") raise DockerError("Can't connect to docker daemon")
return self._connector return self._connector
@ -150,7 +149,7 @@ class Docker(BaseManager):
data=data, data=data,
headers={"content-type": "application/json", }, headers={"content-type": "application/json", },
timeout=timeout) timeout=timeout)
except (aiohttp.ClientResponseError, aiohttp.ClientOSError) as e: except aiohttp.ClientError as e:
raise DockerError("Docker has returned an error: {}".format(str(e))) raise DockerError("Docker has returned an error: {}".format(str(e)))
except (asyncio.TimeoutError): except (asyncio.TimeoutError):
raise DockerError("Docker timeout " + method + " " + path) raise DockerError("Docker timeout " + method + " " + path)

View File

@ -19,7 +19,6 @@
Dynamips server module. Dynamips server module.
""" """
import aiohttp
import sys import sys
import os import os
import shutil import shutil

View File

@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import socket import socket
from aiohttp.web import HTTPConflict from fastapi import HTTPException, status
from gns3server.config import Config from gns3server.config import Config
import logging import logging
@ -48,12 +48,12 @@ class PortManager:
console_start_port_range = server_config.getint("console_start_port_range", 5000) console_start_port_range = server_config.getint("console_start_port_range", 5000)
console_end_port_range = server_config.getint("console_end_port_range", 10000) console_end_port_range = server_config.getint("console_end_port_range", 10000)
self._console_port_range = (console_start_port_range, console_end_port_range) self._console_port_range = (console_start_port_range, console_end_port_range)
log.debug("Console port range is {}-{}".format(console_start_port_range, console_end_port_range)) log.debug(f"Console port range is {console_start_port_range}-{console_end_port_range}")
udp_start_port_range = server_config.getint("udp_start_port_range", 20000) udp_start_port_range = server_config.getint("udp_start_port_range", 20000)
udp_end_port_range = server_config.getint("udp_end_port_range", 30000) udp_end_port_range = server_config.getint("udp_end_port_range", 30000)
self._udp_port_range = (udp_start_port_range, udp_end_port_range) self._udp_port_range = (udp_start_port_range, udp_end_port_range)
log.debug("UDP port range is {}-{}".format(udp_start_port_range, udp_end_port_range)) log.debug(f"UDP port range is {udp_start_port_range}-{udp_end_port_range}")
@classmethod @classmethod
def instance(cls): def instance(cls):
@ -149,7 +149,8 @@ class PortManager:
""" """
if end_port < start_port: if end_port < start_port:
raise HTTPConflict(text="Invalid port range {}-{}".format(start_port, end_port)) raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"Invalid port range {start_port}-{end_port}")
last_exception = None last_exception = None
for port in range(start_port, end_port + 1): for port in range(start_port, end_port + 1):
@ -168,10 +169,9 @@ class PortManager:
else: else:
continue continue
raise HTTPConflict(text="Could not find a free port between {} and {} on host {}, last exception: {}".format(start_port, raise HTTPException(status_code=status.HTTP_409_CONFLICT,
end_port, detail=f"Could not find a free port between {start_port} and {end_port} on host {host},"
host, f" last exception: {last_exception}")
last_exception))
@staticmethod @staticmethod
def _check_port(host, port, socket_type): def _check_port(host, port, socket_type):
@ -212,7 +212,7 @@ class PortManager:
self._used_tcp_ports.add(port) self._used_tcp_ports.add(port)
project.record_tcp_port(port) project.record_tcp_port(port)
log.debug("TCP port {} has been allocated".format(port)) log.debug(f"TCP port {port} has been allocated")
return port return port
def reserve_tcp_port(self, port, project, port_range_start=None, port_range_end=None): def reserve_tcp_port(self, port, project, port_range_start=None, port_range_end=None):
@ -235,13 +235,14 @@ class PortManager:
if port in self._used_tcp_ports: if port in self._used_tcp_ports:
old_port = port old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port) msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}"
log.debug(msg) log.debug(msg)
return port return port
if port < port_range_start or port > port_range_end: if port < port_range_start or port > port_range_end:
old_port = port old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} is outside the range {}-{} on host {}. Port has been replaced by {}".format(old_port, port_range_start, port_range_end, self._console_host, port) msg = f"TCP port {old_port} is outside the range {port_range_start}-{port_range_end} on host " \
f"{self._console_host}. Port has been replaced by {port}"
log.debug(msg) log.debug(msg)
return port return port
try: try:
@ -249,13 +250,13 @@ class PortManager:
except OSError: except OSError:
old_port = port old_port = port
port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end) port = self.get_free_tcp_port(project, port_range_start=port_range_start, port_range_end=port_range_end)
msg = "TCP port {} already in use on host {}. Port has been replaced by {}".format(old_port, self._console_host, port) msg = f"TCP port {old_port} already in use on host {self._console_host}. Port has been replaced by {port}"
log.debug(msg) log.debug(msg)
return port return port
self._used_tcp_ports.add(port) self._used_tcp_ports.add(port)
project.record_tcp_port(port) project.record_tcp_port(port)
log.debug("TCP port {} has been reserved".format(port)) log.debug(f"TCP port {port} has been reserved")
return port return port
def release_tcp_port(self, port, project): def release_tcp_port(self, port, project):
@ -269,7 +270,7 @@ class PortManager:
if port in self._used_tcp_ports: if port in self._used_tcp_ports:
self._used_tcp_ports.remove(port) self._used_tcp_ports.remove(port)
project.remove_tcp_port(port) project.remove_tcp_port(port)
log.debug("TCP port {} has been released".format(port)) log.debug(f"TCP port {port} has been released")
def get_free_udp_port(self, project): def get_free_udp_port(self, project):
""" """
@ -285,7 +286,7 @@ class PortManager:
self._used_udp_ports.add(port) self._used_udp_ports.add(port)
project.record_udp_port(port) project.record_udp_port(port)
log.debug("UDP port {} has been allocated".format(port)) log.debug(f"UDP port {port} has been allocated")
return port return port
def reserve_udp_port(self, port, project): def reserve_udp_port(self, port, project):
@ -297,9 +298,12 @@ class PortManager:
""" """
if port in self._used_udp_ports: if port in self._used_udp_ports:
raise HTTPConflict(text="UDP port {} already in use on host {}".format(port, self._console_host)) raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"UDP port {port} already in use on host {self._console_host}")
if port < self._udp_port_range[0] or port > self._udp_port_range[1]: if port < self._udp_port_range[0] or port > self._udp_port_range[1]:
raise HTTPConflict(text="UDP port {} is outside the range {}-{}".format(port, self._udp_port_range[0], self._udp_port_range[1])) raise HTTPException(status_code=status.HTTP_409_CONFLICT,
detail=f"UDP port {port} is outside the range "
f"{self._udp_port_range[0]}-{self._udp_port_range[1]}")
self._used_udp_ports.add(port) self._used_udp_ports.add(port)
project.record_udp_port(port) project.record_udp_port(port)
log.debug("UDP port {} has been reserved".format(port)) log.debug("UDP port {} has been reserved".format(port))
@ -315,4 +319,4 @@ class PortManager:
if port in self._used_udp_ports: if port in self._used_udp_ports:
self._used_udp_ports.remove(port) self._used_udp_ports.remove(port)
project.remove_udp_port(port) project.remove_udp_port(port)
log.debug("UDP port {} has been released".format(port)) log.debug(f"UDP port {port} has been released")

View File

@ -19,12 +19,12 @@ import os
import json import json
import uuid import uuid
import asyncio import asyncio
import aiohttp
from .appliance import Appliance from .appliance import Appliance
from ..config import Config from ..config import Config
from ..utils.asyncio import locking from ..utils.asyncio import locking
from ..utils.get_resource import get_resource from ..utils.get_resource import get_resource
from ..utils.http_client import HTTPClient
from .controller_error import ControllerError from .controller_error import ControllerError
import logging import logging
@ -142,20 +142,19 @@ class ApplianceManager:
""" """
symbol_url = "https://raw.githubusercontent.com/GNS3/gns3-registry/master/symbols/{}".format(symbol) symbol_url = "https://raw.githubusercontent.com/GNS3/gns3-registry/master/symbols/{}".format(symbol)
async with aiohttp.ClientSession() as session: async with HTTPClient.get(symbol_url) as response:
async with session.get(symbol_url) as response: if response.status != 200:
if response.status != 200: log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status))
log.warning("Could not retrieve appliance symbol {} from GitHub due to HTTP error code {}".format(symbol, response.status)) else:
else: try:
try: symbol_data = await response.read()
symbol_data = await response.read() log.info("Saving {} symbol to {}".format(symbol, destination_path))
log.info("Saving {} symbol to {}".format(symbol, destination_path)) with open(destination_path, 'wb') as f:
with open(destination_path, 'wb') as f: f.write(symbol_data)
f.write(symbol_data) except asyncio.TimeoutError:
except asyncio.TimeoutError: log.warning("Timeout while downloading '{}'".format(symbol_url))
log.warning("Timeout while downloading '{}'".format(symbol_url)) except OSError as e:
except OSError as e: log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e))
log.warning("Could not write appliance symbol '{}': {}".format(destination_path, e))
@locking @locking
async def download_appliances(self): async def download_appliances(self):
@ -168,40 +167,40 @@ class ApplianceManager:
if self._appliances_etag: if self._appliances_etag:
log.info("Checking if appliances are up-to-date (ETag {})".format(self._appliances_etag)) log.info("Checking if appliances are up-to-date (ETag {})".format(self._appliances_etag))
headers["If-None-Match"] = self._appliances_etag headers["If-None-Match"] = self._appliances_etag
async with aiohttp.ClientSession() as session:
async with session.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response: async with HTTPClient.get('https://api.github.com/repos/GNS3/gns3-registry/contents/appliances', headers=headers) as response:
if response.status == 304: if response.status == 304:
log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag)) log.info("Appliances are already up-to-date (ETag {})".format(self._appliances_etag))
return return
elif response.status != 200: elif response.status != 200:
raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status)) raise ControllerError("Could not retrieve appliances from GitHub due to HTTP error code {}".format(response.status))
etag = response.headers.get("ETag") etag = response.headers.get("ETag")
if etag: if etag:
self._appliances_etag = etag self._appliances_etag = etag
from . import Controller from . import Controller
Controller.instance().save() Controller.instance().save()
json_data = await response.json() json_data = await response.json()
appliances_dir = get_resource('appliances') appliances_dir = get_resource('appliances')
for appliance in json_data: for appliance in json_data:
if appliance["type"] == "file": if appliance["type"] == "file":
appliance_name = appliance["name"] appliance_name = appliance["name"]
log.info("Download appliance file from '{}'".format(appliance["download_url"])) log.info("Download appliance file from '{}'".format(appliance["download_url"]))
async with session.get(appliance["download_url"]) as response: async with HTTPClient.get(appliance["download_url"]) as response:
if response.status != 200: if response.status != 200:
log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status)) log.warning("Could not download '{}' due to HTTP error code {}".format(appliance["download_url"], response.status))
continue continue
try: try:
appliance_data = await response.read() appliance_data = await response.read()
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning("Timeout while downloading '{}'".format(appliance["download_url"])) log.warning("Timeout while downloading '{}'".format(appliance["download_url"]))
continue continue
path = os.path.join(appliances_dir, appliance_name) path = os.path.join(appliances_dir, appliance_name)
try: try:
log.info("Saving {} file to {}".format(appliance_name, path)) log.info("Saving {} file to {}".format(appliance_name, path))
with open(path, 'wb') as f: with open(path, 'wb') as f:
f.write(appliance_data) f.write(appliance_data)
except OSError as e: except OSError as e:
raise ControllerError("Could not write appliance file '{}': {}".format(path, e)) raise ControllerError("Could not write appliance file '{}': {}".format(path, e))
except ValueError as e: except ValueError as e:
raise ControllerError("Could not read appliances information from GitHub: {}".format(e)) raise ControllerError("Could not read appliances information from GitHub: {}".format(e))

View File

@ -63,7 +63,9 @@ class Compute:
A GNS3 compute. A GNS3 compute.
""" """
def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None, password=None, name=None, console_host=None): def __init__(self, compute_id, controller=None, protocol="http", host="localhost", port=3080, user=None,
password=None, name=None, console_host=None):
self._http_session = None self._http_session = None
assert controller is not None assert controller is not None
log.info("Create compute %s", compute_id) log.info("Create compute %s", compute_id)
@ -103,14 +105,10 @@ class Compute:
def _session(self): def _session(self):
if self._http_session is None or self._http_session.closed is True: if self._http_session is None or self._http_session.closed is True:
self._http_session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True)) connector = aiohttp.TCPConnector(force_close=True)
self._http_session = aiohttp.ClientSession(connector=connector)
return self._http_session return self._http_session
#def __del__(self):
#
# if self._http_session:
# self._http_session.close()
def _set_auth(self, user, password): def _set_auth(self, user, password):
""" """
Set authentication parameters Set authentication parameters
@ -466,7 +464,7 @@ class Compute:
elif response.type == aiohttp.WSMsgType.CLOSED: elif response.type == aiohttp.WSMsgType.CLOSED:
pass pass
break break
except aiohttp.client_exceptions.ClientResponseError as e: except aiohttp.ClientError as e:
log.error("Client response error received on compute '{}' WebSocket '{}': {}".format(self._id, ws_url,e)) log.error("Client response error received on compute '{}' WebSocket '{}': {}".format(self._id, ws_url,e))
finally: finally:
self._connected = False self._connected = False
@ -503,8 +501,7 @@ class Compute:
async def _run_http_query(self, method, path, data=None, timeout=20, raw=False): async def _run_http_query(self, method, path, data=None, timeout=20, raw=False):
with async_timeout.timeout(timeout): with async_timeout.timeout(timeout):
url = self._getUrl(path) url = self._getUrl(path)
headers = {} headers = {'content-type': 'application/json'}
headers['content-type'] = 'application/json'
chunked = None chunked = None
if data == {}: if data == {}:
data = None data = None
@ -579,7 +576,7 @@ class Compute:
return response return response
async def get(self, path, **kwargs): async def get(self, path, **kwargs):
return (await self.http_query("GET", path, **kwargs)) return await self.http_query("GET", path, **kwargs)
async def post(self, path, data={}, **kwargs): async def post(self, path, data={}, **kwargs):
response = await self.http_query("POST", path, data, **kwargs) response = await self.http_query("POST", path, data, **kwargs)
@ -600,15 +597,13 @@ class Compute:
action = "/{}/{}".format(type, path) action = "/{}/{}".format(type, path)
res = await self.http_query(method, action, data=data, timeout=None) res = await self.http_query(method, action, data=data, timeout=None)
except aiohttp.ServerDisconnectedError: except aiohttp.ServerDisconnectedError:
log.error("Connection lost to %s during %s %s", self._id, method, action) raise ControllerError(f"Connection lost to {self._id} during {method} {action}")
raise aiohttp.web.HTTPGatewayTimeout()
return res.json return res.json
async def images(self, type): async def images(self, type):
""" """
Return the list of images available for this type on the compute node. Return the list of images available for this type on the compute node.
""" """
images = []
res = await self.http_query("GET", "/{}/images".format(type), timeout=None) res = await self.http_query("GET", "/{}/images".format(type), timeout=None)
images = res.json images = res.json
@ -641,11 +636,11 @@ class Compute:
:returns: Tuple (ip_for_this_compute, ip_for_other_compute) :returns: Tuple (ip_for_this_compute, ip_for_other_compute)
""" """
if other_compute == self: if other_compute == self:
return (self.host_ip, self.host_ip) return self.host_ip, self.host_ip
# Perhaps the user has correct network gateway, we trust him # Perhaps the user has correct network gateway, we trust him
if (self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1')): if self.host_ip not in ('0.0.0.0', '127.0.0.1') and other_compute.host_ip not in ('0.0.0.0', '127.0.0.1'):
return (self.host_ip, other_compute.host_ip) return self.host_ip, other_compute.host_ip
this_compute_interfaces = await self.interfaces() this_compute_interfaces = await self.interfaces()
other_compute_interfaces = await other_compute.interfaces() other_compute_interfaces = await other_compute.interfaces()
@ -675,6 +670,6 @@ class Compute:
other_network = ipaddress.ip_network("{}/{}".format(other_interface["ip_address"], other_interface["netmask"]), strict=False) other_network = ipaddress.ip_network("{}/{}".format(other_interface["ip_address"], other_interface["netmask"]), strict=False)
if this_network.overlaps(other_network): if this_network.overlaps(other_network):
return (this_interface["ip_address"], other_interface["ip_address"]) return this_interface["ip_address"], other_interface["ip_address"]
raise ValueError("No common subnet for compute {} and {}".format(self.name, other_compute.name)) raise ValueError("No common subnet for compute {} and {}".format(self.name, other_compute.name))

View File

@ -24,6 +24,7 @@ import socket
from .base_gns3_vm import BaseGNS3VM from .base_gns3_vm import BaseGNS3VM
from .gns3_vm_error import GNS3VMError from .gns3_vm_error import GNS3VMError
from gns3server.utils import parse_version from gns3server.utils import parse_version
from gns3server.utils.http_client import HTTPClient
from gns3server.utils.asyncio import wait_run_in_executor from gns3server.utils.asyncio import wait_run_in_executor
from ...compute.virtualbox import ( from ...compute.virtualbox import (
@ -305,24 +306,24 @@ class VirtualBoxGNS3VM(BaseGNS3VM):
second to a GNS3 endpoint in order to get the list of the interfaces and second to a GNS3 endpoint in order to get the list of the interfaces and
their IP and after that match it with VirtualBox host only. their IP and after that match it with VirtualBox host only.
""" """
remaining_try = 300 remaining_try = 300
while remaining_try > 0: while remaining_try > 0:
async with aiohttp.ClientSession() as session: try:
try: async with HTTPClient.get(f"http://127.0.0.1:{api_port}/v2/compute/network/interfaces") as resp:
async with session.get('http://127.0.0.1:{}/v2/compute/network/interfaces'.format(api_port)) as resp: if resp.status < 300:
if resp.status < 300: try:
try: json_data = await resp.json()
json_data = await resp.json() if json_data:
if json_data: for interface in json_data:
for interface in json_data: if "name" in interface and interface["name"] == "eth{}".format(
if "name" in interface and interface["name"] == "eth{}".format( hostonly_interface_number - 1):
hostonly_interface_number - 1): if "ip_address" in interface and len(interface["ip_address"]) > 0:
if "ip_address" in interface and len(interface["ip_address"]) > 0: return interface["ip_address"]
return interface["ip_address"] except ValueError:
except ValueError: pass
pass except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError):
except (OSError, aiohttp.ClientError, TimeoutError, asyncio.TimeoutError): pass
pass
remaining_try -= 1 remaining_try -= 1
await asyncio.sleep(1) await asyncio.sleep(1)
raise GNS3VMError("Could not find guest IP address for {}".format(self.vmname)) raise GNS3VMError("Could not find guest IP address for {}".format(self.vmname))

View File

@ -19,8 +19,8 @@
API endpoints for links. API endpoints for links.
""" """
import aiohttp
import multidict import multidict
import aiohttp
from fastapi import APIRouter, Depends, Request, status from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@ -31,9 +31,13 @@ from uuid import UUID
from gns3server.controller import Controller from gns3server.controller import Controller
from gns3server.controller.controller_error import ControllerError from gns3server.controller.controller_error import ControllerError
from gns3server.controller.link import Link from gns3server.controller.link import Link
from gns3server.utils.http_client import HTTPClient
from gns3server.endpoints.schemas.common import ErrorMessage from gns3server.endpoints.schemas.common import ErrorMessage
from gns3server.endpoints import schemas from gns3server.endpoints import schemas
import logging
log = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
responses = { responses = {
@ -201,12 +205,13 @@ async def pcap(request: Request, link: Link = Depends(dep_link)):
async def compute_pcap_stream(): async def compute_pcap_stream():
connector = aiohttp.TCPConnector(limit=None, force_close=True) try:
async with aiohttp.ClientSession(connector=connector, headers=headers) as session: async with HTTPClient.request(request.method, pcap_streaming_url, timeout=None, data=body) as response:
async with session.request(request.method, pcap_streaming_url, timeout=None, data=body) as compute_response: async for data in response.content.iter_any():
async for data in compute_response.content.iter_any():
if not data: if not data:
break break
yield data yield data
except aiohttp.ClientError as e:
raise ControllerError(f"Client error received when receiving pcap stream from compute: {e}")
return StreamingResponse(compute_pcap_stream(), media_type="application/vnd.tcpdump.pcap") return StreamingResponse(compute_pcap_stream(), media_type="application/vnd.tcpdump.pcap")

View File

@ -32,6 +32,7 @@ from gns3server.controller import Controller
from gns3server.controller.node import Node from gns3server.controller.node import Node
from gns3server.controller.project import Project from gns3server.controller.project import Project
from gns3server.utils import force_unix_path from gns3server.utils import force_unix_path
from gns3server.utils.http_client import HTTPClient
from gns3server.controller.controller_error import ControllerForbiddenError from gns3server.controller.controller_error import ControllerForbiddenError
from gns3server.endpoints.schemas.common import ErrorMessage from gns3server.endpoints.schemas.common import ErrorMessage
from gns3server.endpoints import schemas from gns3server.endpoints import schemas
@ -400,18 +401,17 @@ async def ws_console(websocket: WebSocket, node: Node = Depends(dep_node)):
try: try:
# receive WebSocket data from compute console WebSocket and forward to client. # receive WebSocket data from compute console WebSocket and forward to client.
async with aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=None, force_close=True)) as session: async with HTTPClient.get_client().ws_connect(ws_console_compute_url) as ws_console_compute:
async with session.ws_connect(ws_console_compute_url) as ws_console_compute: asyncio.ensure_future(ws_receive(ws_console_compute))
asyncio.ensure_future(ws_receive(ws_console_compute)) async for msg in ws_console_compute:
async for msg in ws_console_compute: if msg.type == aiohttp.WSMsgType.TEXT:
if msg.type == aiohttp.WSMsgType.TEXT: await websocket.send_text(msg.data)
await websocket.send_text(msg.data) elif msg.type == aiohttp.WSMsgType.BINARY:
elif msg.type == aiohttp.WSMsgType.BINARY: await websocket.send_bytes(msg.data)
await websocket.send_bytes(msg.data) elif msg.type == aiohttp.WSMsgType.ERROR:
elif msg.type == aiohttp.WSMsgType.ERROR: break
break except aiohttp.ClientError as e:
except aiohttp.client_exceptions.ClientResponseError as e: log.error(f"Client error received when forwarding to compute console WebSocket: {e}")
log.error(f"Client response error received when forwarding to compute console WebSocket: {e}")
@router.post("/console/reset", @router.post("/console/reset",

View File

@ -29,7 +29,7 @@ import time
import logging import logging
log = logging.getLogger() log = logging.getLogger()
from fastapi import APIRouter, Depends, Request, Body, Query, HTTPException, status, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Depends, Request, Body, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, FileResponse from fastapi.responses import StreamingResponse, FileResponse
from websockets.exceptions import ConnectionClosed, WebSocketException from websockets.exceptions import ConnectionClosed, WebSocketException

View File

@ -0,0 +1,69 @@
#!/usr/bin/env python
#
# Copyright (C) 2020 GNS3 Technologies Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import socket
import logging
log = logging.getLogger(__name__)
class HTTPClient:
"""
HTTP client for request to computes and external services.
"""
_aiohttp_client: aiohttp.ClientSession = None
@classmethod
def get_client(cls) -> aiohttp.ClientSession:
if cls._aiohttp_client is None:
cls._aiohttp_client = aiohttp.ClientSession(connector=aiohttp.TCPConnector(family=socket.AF_INET))
return cls._aiohttp_client
@classmethod
async def close_session(cls):
if cls._aiohttp_client:
await cls._aiohttp_client.close()
cls._aiohttp_client = None
@classmethod
def request(cls, method: str, url: str, user: str = None, password: str = None, **kwargs):
client = cls.get_client()
basic_auth = None
if user:
if not password:
password = ""
try:
basic_auth = aiohttp.BasicAuth(user, password, "utf-8")
except ValueError as e:
log.error(f"Basic authentication set-up error: {e}")
return client.request(method, url, auth=basic_auth, **kwargs)
@classmethod
def get(cls, path, **kwargs):
return cls.request("GET", path, **kwargs)
@classmethod
def post(cls, path, **kwargs):
return cls.request("POST", path, **kwargs)
@classmethod
def put(cls, path, **kwargs):
return cls.request("PUT", path, **kwargs)

View File

@ -18,12 +18,12 @@
import os import os
import sys import sys
import aiohttp
import socket import socket
import struct import struct
import psutil import psutil
from .windows_service import check_windows_service_is_running from .windows_service import check_windows_service_is_running
from gns3server.compute.compute_error import ComputeError
from gns3server.config import Config from gns3server.config import Config
if psutil.version_info < (3, 0, 0): if psutil.version_info < (3, 0, 0):
@ -162,7 +162,7 @@ def is_interface_up(interface):
return True return True
return False return False
except OSError as e: except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Exception when checking if {} is up: {}".format(interface, e)) raise ComputeError(f"Exception when checking if {interface} is up: {e}")
else: else:
# TODO: Windows & OSX support # TODO: Windows & OSX support
return True return True
@ -221,13 +221,13 @@ def interfaces():
results = get_windows_interfaces() results = get_windows_interfaces()
except ImportError: except ImportError:
message = "pywin32 module is not installed, please install it on the server to get the available interface names" message = "pywin32 module is not installed, please install it on the server to get the available interface names"
raise aiohttp.web.HTTPInternalServerError(text=message) raise ComputeError(message)
except Exception as e: except Exception as e:
log.error("uncaught exception {type}".format(type=type(e)), exc_info=1) log.error("uncaught exception {type}".format(type=type(e)), exc_info=1)
raise aiohttp.web.HTTPInternalServerError(text="uncaught exception: {}".format(e)) raise ComputeError(f"uncaught exception: {e}")
if service_installed is False: if service_installed is False:
raise aiohttp.web.HTTPInternalServerError(text="The Winpcap or Npcap is not installed or running") raise ComputeError("The Winpcap or Npcap is not installed or running")
# This interface have special behavior # This interface have special behavior
for result in results: for result in results:

View File

@ -16,8 +16,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os import os
import aiohttp
from fastapi import HTTPException, status
from ..config import Config from ..config import Config
@ -33,7 +33,7 @@ def get_default_project_directory():
try: try:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
except OSError as e: except OSError as e:
raise aiohttp.web.HTTPInternalServerError(text="Could not create project directory: {}".format(e)) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"Could not create project directory: {e}")
return path return path
@ -52,4 +52,4 @@ def check_path_allowed(path):
return return
if "local" in config and config.getboolean("local") is False: if "local" in config and config.getboolean("local") is False:
raise aiohttp.web.HTTPForbidden(text="The path is not allowed") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The path is not allowed")

View File

@ -19,7 +19,7 @@
Check for Windows service. Check for Windows service.
""" """
import aiohttp from gns3server.compute.compute_error import ComputeError
def check_windows_service_is_running(service_name): def check_windows_service_is_running(service_name):
@ -35,5 +35,5 @@ def check_windows_service_is_running(service_name):
if e.winerror == 1060: if e.winerror == 1060:
return False return False
else: else:
raise aiohttp.web.HTTPInternalServerError(text="Could not check if the {} service is running: {}".format(service_name, e.strerror)) raise ComputeError(f"Could not check if the {service_name} service is running: {e.strerror}")
return True return True

View File

@ -15,9 +15,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import pytest import pytest
import uuid import uuid
from fastapi import HTTPException
from unittest.mock import patch from unittest.mock import patch
from gns3server.compute.port_manager import PortManager from gns3server.compute.port_manager import PortManager
@ -94,7 +95,7 @@ def test_reserve_udp_port():
pm = PortManager() pm = PortManager()
project = Project(project_id=str(uuid.uuid4())) project = Project(project_id=str(uuid.uuid4()))
pm.reserve_udp_port(20000, project) pm.reserve_udp_port(20000, project)
with pytest.raises(aiohttp.web.HTTPConflict): with pytest.raises(HTTPException):
pm.reserve_udp_port(20000, project) pm.reserve_udp_port(20000, project)
@ -102,7 +103,7 @@ def test_reserve_udp_port_outside_range():
pm = PortManager() pm = PortManager()
project = Project(project_id=str(uuid.uuid4())) project = Project(project_id=str(uuid.uuid4()))
with pytest.raises(aiohttp.web.HTTPConflict): with pytest.raises(HTTPException):
pm.reserve_udp_port(80, project) pm.reserve_udp_port(80, project)
@ -123,7 +124,7 @@ def test_find_unused_port():
def test_find_unused_port_invalid_range(): def test_find_unused_port_invalid_range():
with pytest.raises(aiohttp.web.HTTPConflict): with pytest.raises(HTTPException):
p = PortManager().find_unused_port(10000, 1000) p = PortManager().find_unused_port(10000, 1000)

View File

@ -74,8 +74,6 @@ async def test_compute_get(controller_api):
response = await controller_api.get("/computes/my_compute_id") response = await controller_api.get("/computes/my_compute_id")
assert response.status_code == 200 assert response.status_code == 200
print(response.json)
#assert response.json["protocol"] == "http"
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -17,8 +17,8 @@
import os import os
import pytest import pytest
import aiohttp
from fastapi import HTTPException
from gns3server.utils.path import check_path_allowed, get_default_project_directory from gns3server.utils.path import check_path_allowed, get_default_project_directory
@ -27,7 +27,7 @@ def test_check_path_allowed(config, tmpdir):
config.set("Server", "local", False) config.set("Server", "local", False)
config.set("Server", "projects_path", str(tmpdir)) config.set("Server", "projects_path", str(tmpdir))
with pytest.raises(aiohttp.web.HTTPForbidden): with pytest.raises(HTTPException):
check_path_allowed("/private") check_path_allowed("/private")
config.set("Server", "local", True) config.set("Server", "local", True)