Save the list of compute node

Fix #494
This commit is contained in:
Julien Duponchelle 2016-04-19 15:35:50 +02:00
parent 6463007ef1
commit f5e5cf5059
No known key found for this signature in database
GPG Key ID: CE8B29639E07F5E8
6 changed files with 137 additions and 6 deletions

View File

@ -15,12 +15,19 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import json
import asyncio import asyncio
import aiohttp import aiohttp
from ..config import Config from ..config import Config
from .project import Project from .project import Project
from .compute import Compute from .compute import Compute
from ..version import __version__
import logging
log = logging.getLogger(__name__)
class Controller: class Controller:
@ -30,6 +37,48 @@ class Controller:
self._computes = {} self._computes = {}
self._projects = {} self._projects = {}
if sys.platform.startswith("win"):
config_path = os.path.join(os.path.expandvars("%APPDATA%"), "GNS3")
else:
config_path = os.path.join(os.path.expanduser("~"), ".config", "GNS3")
self._config_file = os.path.join(config_path, "gns3_controller.conf")
def save(self):
"""
Save the controller configuration on disk
"""
data = {
"computes": [ {
"host": c.host,
"port": c.port,
"protocol": c.protocol,
"user": c.user,
"password": c.password,
"compute_id": c.id
} for c in self._computes.values() ],
"version": __version__
}
os.makedirs(os.path.dirname(self._config_file), exist_ok=True)
with open(self._config_file, 'w+') as f:
json.dump(data, f, indent=4)
@asyncio.coroutine
def load(self):
"""
Reload the controller configuration from disk
"""
if not os.path.exists(self._config_file):
return
try:
with open(self._config_file) as f:
data = json.load(f)
except OSError as e:
log.critical("Can not load %s: %s", self._config_file, str(e))
return
for c in data["computes"]:
compute_id = c.pop("compute_id")
yield from self.addCompute(compute_id, **c)
def isEnabled(self): def isEnabled(self):
""" """
:returns: True if current instance is the controller :returns: True if current instance is the controller
@ -47,6 +96,7 @@ class Controller:
if compute_id not in self._computes: if compute_id not in self._computes:
compute = Compute(compute_id=compute_id, controller=self, **kwargs) compute = Compute(compute_id=compute_id, controller=self, **kwargs)
self._computes[compute_id] = compute self._computes[compute_id] = compute
self.save()
return self._computes[compute_id] return self._computes[compute_id]
@property @property

View File

@ -87,6 +87,20 @@ class Compute:
""" """
return self._host return self._host
@property
def port(self):
"""
:returns: Compute port (integer)
"""
return self._port
@property
def protocol(self):
"""
:returns: Compute protocol (string)
"""
return self._protocol
@property @property
def user(self): def user(self):
return self._user return self._user

View File

@ -26,7 +26,6 @@ import datetime
import sys import sys
import locale import locale
import argparse import argparse
import asyncio
from gns3server.web.web_server import WebServer from gns3server.web.web_server import WebServer
from gns3server.web.logger import init_logger from gns3server.web.logger import init_logger
@ -35,6 +34,7 @@ from gns3server.config import Config
from gns3server.compute.project import Project from gns3server.compute.project import Project
from gns3server.crash_report import CrashReport from gns3server.crash_report import CrashReport
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -34,6 +34,8 @@ from .request_handler import RequestHandler
from ..config import Config from ..config import Config
from ..compute import MODULES from ..compute import MODULES
from ..compute.port_manager import PortManager from ..compute.port_manager import PortManager
from ..controller import Controller
# do not delete this import # do not delete this import
import gns3server.handlers import gns3server.handlers
@ -198,6 +200,9 @@ class WebServer:
# Asyncio will raise error if coroutine is not called # Asyncio will raise error if coroutine is not called
self._loop.set_debug(True) self._loop.set_debug(True)
if server_config.getboolean("controller"):
asyncio.async(Controller.instance().load())
for key, val in os.environ.items(): for key, val in os.environ.items():
log.debug("ENV %s=%s", key, val) log.debug("ENV %s=%s", key, val)

View File

@ -174,9 +174,16 @@ def ethernet_device():
@pytest.fixture @pytest.fixture
def controller(): def controller_config_path(tmpdir):
return str(tmpdir / "config" / "gns3_controller.conf")
@pytest.fixture
def controller(tmpdir, controller_config_path):
Controller._instance = None Controller._instance = None
return Controller.instance() controller = Controller.instance()
controller._config_file = controller_config_path
return controller
@pytest.fixture @pytest.fixture

View File

@ -15,8 +15,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import pytest import os
import uuid import uuid
import json
import pytest
import aiohttp import aiohttp
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -25,6 +27,44 @@ from gns3server.controller import Controller
from gns3server.controller.compute import Compute from gns3server.controller.compute import Compute
from gns3server.controller.project import Project from gns3server.controller.project import Project
from gns3server.config import Config from gns3server.config import Config
from gns3server.version import __version__
def test_save(controller, controller_config_path):
controller.save()
assert os.path.exists(controller_config_path)
with open(controller_config_path) as f:
data = json.load(f)
assert data["computes"] == []
assert data["version"] == __version__
def test_load(controller, controller_config_path, async_run):
controller.save()
with open(controller_config_path) as f:
data = json.load(f)
data["computes"] = [
{
"host": "localhost",
"port": 8000,
"protocol": "http",
"user": "admin",
"password": "root",
"compute_id": "test1"
}
]
with open(controller_config_path, "w+") as f:
json.dump(data, f)
async_run(controller.load())
assert len(controller.computes) == 1
assert controller.computes["test1"].__json__() == {
"compute_id": "test1",
"connected": False,
"host": "localhost",
"port": 8000,
"protocol": "http",
"user": "admin"
}
def test_isEnabled(controller): def test_isEnabled(controller):
@ -34,7 +74,7 @@ def test_isEnabled(controller):
assert controller.isEnabled() assert controller.isEnabled()
def test_addCompute(controller, async_run): def test_addCompute(controller, controller_config_path, async_run):
async_run(controller.addCompute("test1")) async_run(controller.addCompute("test1"))
assert len(controller.computes) == 1 assert len(controller.computes) == 1
async_run(controller.addCompute("test1")) async_run(controller.addCompute("test1"))
@ -42,9 +82,24 @@ def test_addCompute(controller, async_run):
async_run(controller.addCompute("test2")) async_run(controller.addCompute("test2"))
assert len(controller.computes) == 2 assert len(controller.computes) == 2
def test_addComputeConfigFile(controller, controller_config_path, async_run):
async_run(controller.addCompute("test1"))
assert len(controller.computes) == 1
with open(controller_config_path) as f:
data = json.load(f)
assert data["computes"] == [
{
'compute_id': 'test1',
'host': 'localhost',
'port': 8000,
'protocol': 'http',
'user': None,
'password': None
}
]
def test_getCompute(controller, async_run): def test_getCompute(controller, async_run):
compute = async_run(controller.addCompute("test1")) compute = async_run(controller.addCompute("test1"))
assert controller.getCompute("test1") == compute assert controller.getCompute("test1") == compute