diff --git a/gns3server/controller/link.py b/gns3server/controller/link.py index 219cefe5..1c8dfa43 100644 --- a/gns3server/controller/link.py +++ b/gns3server/controller/link.py @@ -15,19 +15,24 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os import re import uuid import asyncio +import logging +log = logging.getLogger(__name__) + + class Link: - def __init__(self, project, data_link_type="DLT_EN10MB"): + def __init__(self, project): self._id = str(uuid.uuid4()) self._vms = [] self._project = project - self._data_link_type = data_link_type self._capturing = False + self._capture_file_name = None @asyncio.coroutine def addVM(self, vm, adapter_number, port_number): @@ -55,29 +60,50 @@ class Link: raise NotImplementedError @asyncio.coroutine - def start_capture(self): + def start_capture(self, data_link_type="DLT_EN10MB", capture_file_name=None): """ Start capture on the link :returns: Capture object """ - raise NotImplementedError + self._capturing = True + self._capture_file_name = capture_file_name + self._streaming_pcap = asyncio.async(self._start_streaming_pcap()) + + @asyncio.coroutine + def _start_streaming_pcap(self): + """ + Dump the pcap file on disk + """ + stream = yield from self.read_pcap_from_source() + with open(self.capture_file_path, "wb+") as f: + while self._capturing: + # We read 1 bytes by 1 otherwise if the traffic stop the remaining data is not read + # this is slow + data = yield from stream.read(1) + if data: + f.write(data) + # Flush to disk otherwise the live is not really live + f.flush() + else: + break + yield from stream.close() @asyncio.coroutine def stop_capture(self): """ Stop capture on the link """ - raise NotImplementedError + self._capturing = False @asyncio.coroutine - def read_pcap(self): + def read_pcap_from_source(self): """ Return a FileStream of the Pcap from the compute node """ raise NotImplementedError - def capture_file_name(self): + def default_capture_file_name(self): """ :returns: File name for a capture on this link """ @@ -98,6 +124,16 @@ class Link: def capturing(self): return self._capturing + @property + def capture_file_path(self): + """ + Get the path of the capture + """ + if self._capture_file_name: + return os.path.join(self._project.captures_directory, self._capture_file_name) + else: + return None + def __json__(self): res = [] for side in self._vms: @@ -106,4 +142,8 @@ class Link: "adapter_number": side["adapter_number"], "port_number": side["port_number"] }) - return {"vms": res, "link_id": self._id, "capturing": self._capturing} + return { + "vms": res, "link_id": self._id, + "capturing": self._capturing, + "capture_file_name": self._capture_file_name + } diff --git a/gns3server/controller/udp_link.py b/gns3server/controller/udp_link.py index f335786f..fd4e9e1f 100644 --- a/gns3server/controller/udp_link.py +++ b/gns3server/controller/udp_link.py @@ -80,17 +80,19 @@ class UDPLink(Link): yield from vm2.delete("/adapters/{adapter_number}/ports/{port_number}/nio".format(adapter_number=adapter_number2, port_number=port_number2)) @asyncio.coroutine - def start_capture(self, data_link_type="DLT_EN10MB"): + def start_capture(self, data_link_type="DLT_EN10MB", capture_file_name=None): """ Start capture on a link """ + if not capture_file_name: + capture_file_name = self.default_capture_file_name() self._capture_vm = self._choose_capture_side() data = { - "capture_file_name": self.capture_file_name(), + "capture_file_name": capture_file_name, "data_link_type": data_link_type } yield from self._capture_vm["vm"].post("/adapters/{adapter_number}/ports/{port_number}/start_capture".format(adapter_number=self._capture_vm["adapter_number"], port_number=self._capture_vm["port_number"]), data=data) - self._capturing = True + yield from super().start_capture(data_link_type=data_link_type, capture_file_name=capture_file_name) @asyncio.coroutine def stop_capture(self): @@ -100,7 +102,7 @@ class UDPLink(Link): if self._capture_vm: yield from self._capture_vm["vm"].post("/adapters/{adapter_number}/ports/{port_number}/stop_capture".format(adapter_number=self._capture_vm["adapter_number"], port_number=self._capture_vm["port_number"])) self._capture_vm = None - self._capturing = False + yield from super().stop_capture() def _choose_capture_side(self): """ @@ -124,10 +126,10 @@ class UDPLink(Link): raise aiohttp.web.HTTPConflict(text="Capture is not supported for this link") @asyncio.coroutine - def read_pcap(self): + def read_pcap_from_source(self): """ Return a FileStream of the Pcap from the compute node """ if self._capture_vm: compute = self._capture_vm["vm"].compute - return compute.streamFile(self._project, "tmp/captures/" + self.capture_file_name()) + return compute.streamFile(self._project, "tmp/captures/" + self._capture_file_name) diff --git a/gns3server/handlers/api/controller/link_handler.py b/gns3server/handlers/api/controller/link_handler.py index 58291a85..2f068226 100644 --- a/gns3server/handlers/api/controller/link_handler.py +++ b/gns3server/handlers/api/controller/link_handler.py @@ -73,7 +73,7 @@ class LinkHandler: controller = Controller.instance() project = controller.getProject(request.match_info["project_id"]) link = project.getLink(request.match_info["link_id"]) - yield from link.start_capture(request.json.get("data_link_type", "DLT_EN10MB")) + yield from link.start_capture(data_link_type=request.json.get("data_link_type", "DLT_EN10MB"), capture_file_name=request.json.get("capture_file_name")) response.set_status(204) @classmethod @@ -136,20 +136,24 @@ class LinkHandler: project = controller.getProject(request.match_info["project_id"]) link = project.getLink(request.match_info["link_id"]) - content = yield from link.read_pcap() - if content is None: + if link.capture_file_path is None: raise aiohttp.web.HTTPNotFound(text="pcap file not found") - response.content_type = "application/vnd.tcpdump.pcap" - response.set_status(200) - response.enable_chunked_encoding() - # Very important: do not send a content length otherwise QT close the connection but curl can consume the Feed - response.content_length = None + try: + print(link.capture_file_path) + with open(link.capture_file_path, "rb") as f: - response.start(request) + response.content_type = "application/vnd.tcpdump.pcap" + response.set_status(200) + response.enable_chunked_encoding() + # Very important: do not send a content length otherwise QT close the connection but curl can consume the Feed + response.content_length = None + response.start(request) - while True: - chunk = yield from content.read(4096) - if not chunk: - yield from asyncio.sleep(0.1) - yield from response.write(chunk) + while True: + chunk = f.read(4096) + if not chunk: + break + yield from response.write(chunk) + except OSError: + raise aiohttp.web.HTTPNotFound(text="pcap file {} not found or not accessible".format(link.capture_file_path)) diff --git a/gns3server/schemas/link.py b/gns3server/schemas/link.py index 4d0db2f1..a82ae216 100644 --- a/gns3server/schemas/link.py +++ b/gns3server/schemas/link.py @@ -55,9 +55,13 @@ LINK_OBJECT_SCHEMA = { } }, "capturing": { - "description": "Read only propertie. Is a capture running on the link", + "description": "Read only propertie. True if a capture running on the link", "type": "boolean" }, + "capture_file_name": { + "description": "Read only propertie. The name of the capture file if capture is running", + "type": ["string", "null"] + } }, "required": ["vms"], "additionalProperties": False @@ -72,6 +76,10 @@ LINK_CAPTURE_SCHEMA = { "data_link_type": { "description": "PCAP data link type (http://www.tcpdump.org/linktypes.html)", "enum": ["DLT_ATM_RFC1483", "DLT_EN10MB", "DLT_FRELAY", "DLT_C_HDLC"] + }, + "capture_file_name": { + "description": "Read only propertie. The name of the capture file if capture is running", + "type": "string" } }, "additionalProperties": False diff --git a/tests/controller/test_link.py b/tests/controller/test_link.py index 6c7610c1..eca2ff6b 100644 --- a/tests/controller/test_link.py +++ b/tests/controller/test_link.py @@ -15,7 +15,9 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os import pytest +import asyncio from unittest.mock import MagicMock @@ -24,6 +26,8 @@ from gns3server.controller.vm import VM from gns3server.controller.compute import Compute from gns3server.controller.project import Project +from tests.utils import AsyncioBytesIO + @pytest.fixture def project(): @@ -35,6 +39,17 @@ def compute(): return Compute("example.com", controller=MagicMock()) +@pytest.fixture +def link(async_run, project, compute): + vm1 = VM(project, compute) + vm2 = VM(project, compute) + + link = Link(project) + async_run(link.addVM(vm1, 0, 4)) + async_run(link.addVM(vm2, 1, 3)) + return link + + def test_addVM(async_run, project, compute): vm1 = VM(project, compute) @@ -70,15 +85,33 @@ def test_json(async_run, project, compute): "port_number": 3 } ], - "capturing": False + "capturing": False, + "capture_file_name": None } -def test_capture_filename(project, compute, async_run): +def test_start_streaming_pcap(link, async_run, tmpdir, project): + @asyncio.coroutine + def fake_reader(): + output = AsyncioBytesIO() + yield from output.write(b"hello") + output.seek(0) + return output + + link._capture_file_name = "test.pcap" + link._capturing = True + link.read_pcap_from_source = fake_reader + async_run(link._start_streaming_pcap()) + with open(os.path.join(project.captures_directory, "test.pcap"), "rb") as f: + c = f.read() + assert c == b"hello" + + +def test_default_capture_file_name(project, compute, async_run): vm1 = VM(project, compute, name="Hello@") vm2 = VM(project, compute, name="w0.rld") link = Link(project) async_run(link.addVM(vm1, 0, 4)) async_run(link.addVM(vm2, 1, 3)) - assert link.capture_file_name() == "Hello_0-4_to_w0rld_1-3.pcap" + assert link.default_capture_file_name() == "Hello_0-4_to_w0rld_1-3.pcap" diff --git a/tests/controller/test_project.py b/tests/controller/test_project.py index 94119754..c6aaf64b 100644 --- a/tests/controller/test_project.py +++ b/tests/controller/test_project.py @@ -63,6 +63,12 @@ def test_changing_path_with_quote_not_allowed(tmpdir): p.path = str(tmpdir / "project\"53") +def test_captures_directory(tmpdir): + p = Project(path=str(tmpdir)) + assert p.captures_directory == str(tmpdir / "project-files" / "captures") + assert os.path.exists(p.captures_directory) + + def test_addVM(async_run): compute = MagicMock() project = Project() diff --git a/tests/controller/test_udp_link.py b/tests/controller/test_udp_link.py index 7bd21eb3..0cac6516 100644 --- a/tests/controller/test_udp_link.py +++ b/tests/controller/test_udp_link.py @@ -153,7 +153,7 @@ def test_capture(async_run, project): assert link.capturing compute1.post.assert_any_call("/projects/{}/iou/vms/{}/adapters/3/ports/1/start_capture".format(project.id, vm_iou.id), data={ - "capture_file_name": link.capture_file_name(), + "capture_file_name": link.default_capture_file_name(), "data_link_type": "DLT_EN10MB" }) @@ -163,7 +163,7 @@ def test_capture(async_run, project): compute1.post.assert_any_call("/projects/{}/iou/vms/{}/adapters/3/ports/1/stop_capture".format(project.id, vm_iou.id)) -def test_read_pcap(project, async_run): +def test_read_pcap_from_source(project, async_run): compute1 = MagicMock() link = UDPLink(project) @@ -173,5 +173,5 @@ def test_read_pcap(project, async_run): capture = async_run(link.start_capture()) assert link._capture_vm is not None - async_run(link.read_pcap()) - link._capture_vm["vm"].compute.streamFile.assert_called_with(project, "tmp/captures/" + link.capture_file_name()) + async_run(link.read_pcap_from_source()) + link._capture_vm["vm"].compute.streamFile.assert_called_with(project, "tmp/captures/" + link._capture_file_name) diff --git a/tests/handlers/api/controller/test_link.py b/tests/handlers/api/controller/test_link.py index 521b841e..d9d94aa1 100644 --- a/tests/handlers/api/controller/test_link.py +++ b/tests/handlers/api/controller/test_link.py @@ -97,11 +97,13 @@ def test_stop_capture(http_controller, tmpdir, project, compute, async_run): def test_pcap(http_controller, tmpdir, project, compute, async_run): link = Link(project) - link + link._capture_file_name = "test" + link._capturing = True + with open(link.capture_file_path, "w+") as f: + f.write("hello") project._links = {link.id: link} - with asyncio_patch("gns3server.controller.link.Link.read_pcap", return_value=None) as mock: - response = http_controller.get("/projects/{}/links/{}/pcap".format(project.id, link.id), example=True) - assert mock.called + response = http_controller.get("/projects/{}/links/{}/pcap".format(project.id, link.id), example=True) + assert response.body == b"hello" def test_delete_link(http_controller, tmpdir, project, compute, async_run): diff --git a/tests/utils.py b/tests/utils.py index a20dc98c..70ce6e01 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,6 +15,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import io import asyncio import unittest.mock @@ -87,3 +88,22 @@ class AsyncioMagicMock(unittest.mock.MagicMock): Original code: https://github.com/python/cpython/blob/121f86338111e49c547a55eb7f26db919bfcbde9/Lib/unittest/mock.py """ return AsyncioMagicMock(**kw) + + +class AsyncioBytesIO(io.BytesIO): + """ + An async wrapper arround io.BytesIO to fake an + async network connection + """ + + @asyncio.coroutine + def read(self, length=-1): + return super().read(length) + + @asyncio.coroutine + def write(self, data): + return super().write(data) + + @asyncio.coroutine + def close(self): + return super().close()