Add a "StopIfFree" node command.

Agent: the scheduler handles NodeCommand::StopIfFree by moving to the Done
state (cause: Stopped) only when it is currently in the Free state; in any
other state the command is a no-op.

Service: Node.send_stop_if_free() sends the command only to nodes whose
reported version is at least 2.16.1 (checked with the new semver-based
is_minimum_version helper). It is called when a node is being flagged for
reimage and for every node of a scaleset that is being shrunk.

diff --git a/src/agent/onefuzz-supervisor/src/coordinator.rs b/src/agent/onefuzz-supervisor/src/coordinator.rs
index 02699c877..c19911a6e 100644
--- a/src/agent/onefuzz-supervisor/src/coordinator.rs
+++ b/src/agent/onefuzz-supervisor/src/coordinator.rs
@@ -25,6 +25,7 @@ pub enum NodeCommand {
     AddSshKey(SshKeyInfo),
     StopTask(StopTask),
     Stop {},
+    StopIfFree {},
 }
 
 #[derive(Debug, Deserialize, Eq, PartialEq, Serialize)]
diff --git a/src/agent/onefuzz-supervisor/src/scheduler.rs b/src/agent/onefuzz-supervisor/src/scheduler.rs
index 4dd1dc2ee..a3824410d 100644
--- a/src/agent/onefuzz-supervisor/src/scheduler.rs
+++ b/src/agent/onefuzz-supervisor/src/scheduler.rs
@@ -71,6 +71,15 @@ impl Scheduler {
                 };
                 *self = state.into();
             }
+            NodeCommand::StopIfFree {} => {
+                if let Scheduler::Free(_) = self {
+                    let cause = DoneCause::Stopped;
+                    let state = State {
+                        ctx: Done { cause },
+                    };
+                    *self = state.into();
+                }
+            }
         }
 
         Ok(())
diff --git a/src/api-service/__app__/onefuzzlib/versions.py b/src/api-service/__app__/onefuzzlib/versions.py
index 7b822fe10..19edce746 100644
--- a/src/api-service/__app__/onefuzzlib/versions.py
+++ b/src/api-service/__app__/onefuzzlib/versions.py
@@ -6,6 +6,7 @@
 import os
 from typing import Dict
 
+import semver
 from memoization import cached
 from onefuzztypes.responses import Version
 
@@ -29,3 +30,8 @@ def versions() -> Dict[str, Version]:
         version=__version__,
     )
     return {"onefuzz": entry}
+
+
+def is_minimum_version(*, version: str, minimum: str) -> bool:
+    # check if version is at least (or higher) than minimum
+    return bool(semver.VersionInfo.parse(version).compare(minimum) >= 0)
diff --git a/src/api-service/__app__/onefuzzlib/workers/nodes.py b/src/api-service/__app__/onefuzzlib/workers/nodes.py
index 7858688f5..2919ea765 100644
--- a/src/api-service/__app__/onefuzzlib/workers/nodes.py
+++ b/src/api-service/__app__/onefuzzlib/workers/nodes.py
@@ -16,7 +16,12 @@ from onefuzztypes.events import (
 )
 from onefuzztypes.models import Error
 from onefuzztypes.models import Node as BASE_NODE
-from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey
+from onefuzztypes.models import (
+    NodeAssignment,
+    NodeCommand,
+    NodeCommandAddSshKey,
+    NodeCommandStopIfFree,
+)
 from onefuzztypes.models import NodeTasks as BASE_NODE_TASK
 from onefuzztypes.models import Result, StopNodeCommand, StopTaskNodeCommand
 from onefuzztypes.primitives import PoolName
@@ -26,6 +31,7 @@ from ..__version__ import __version__
 from ..azure.vmss import get_instance_id
 from ..events import send_event
 from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
+from ..versions import is_minimum_version
 
 NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1)
 NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7)
@@ -338,6 +344,11 @@ class Node(BASE_NODE, ORMMixin):
         if not self.reimage_requested and not self.delete_requested:
             logging.info("setting reimage_requested: %s", self.machine_id)
             self.reimage_requested = True
+
+        # if we're going to reimage, make sure the node doesn't pick up new work
+        # too.
+        self.send_stop_if_free()
+
         self.save()
 
     def add_ssh_public_key(self, public_key: str) -> Result[None]:
@@ -355,6 +366,10 @@ class Node(BASE_NODE, ORMMixin):
         )
         return None
 
+    def send_stop_if_free(self) -> None:
+        if is_minimum_version(version=self.version, minimum="2.16.1"):
+            self.send_message(NodeCommand(stop_if_free=NodeCommandStopIfFree()))
+
     def stop(self, done: bool = False) -> None:
         self.to_reimage(done=done)
         self.send_message(NodeCommand(stop=StopNodeCommand()))
diff --git a/src/api-service/__app__/onefuzzlib/workers/scalesets.py b/src/api-service/__app__/onefuzzlib/workers/scalesets.py
index a8b3c1984..f05b42f69 100644
--- a/src/api-service/__app__/onefuzzlib/workers/scalesets.py
+++ b/src/api-service/__app__/onefuzzlib/workers/scalesets.py
@@ -440,10 +440,19 @@ class Scaleset(BASE_SCALESET, ORMMixin):
             return
 
     def _resize_shrink(self, to_remove: int) -> None:
+        logging.info(
+            SCALESET_LOG_PREFIX + "shrinking scaleset. scaleset_id:%s to_remove:%d",
+            self.scaleset_id,
+            to_remove,
+        )
         queue = ScalesetShrinkQueue(self.scaleset_id)
         for _ in range(to_remove):
             queue.add_entry()
 
+        nodes = Node.search_states(scaleset_id=self.scaleset_id)
+        for node in nodes:
+            node.send_stop_if_free()
+
     def resize(self) -> None:
         # no longer needing to resize
         if self.state != ScalesetState.resize:
diff --git a/src/api-service/__app__/requirements.txt b/src/api-service/__app__/requirements.txt
index 0d3af738a..4683239e9 100644
--- a/src/api-service/__app__/requirements.txt
+++ b/src/api-service/__app__/requirements.txt
@@ -30,5 +30,6 @@ memoization~=0.3.1
 github3.py~=1.3.0
 typing-extensions~=3.7.4.3
 jsonpatch==1.28
+semver==2.13.0
 # onefuzz types version is set during build
 onefuzztypes==0.0.0
diff --git a/src/api-service/mypy.ini b/src/api-service/mypy.ini
index 330b7bff9..1248542ba 100644
--- a/src/api-service/mypy.ini
+++ b/src/api-service/mypy.ini
@@ -36,4 +36,7 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 [mypy-jsonpatch.*]
-ignore_missing_imports = True
\ No newline at end of file
+ignore_missing_imports = True
+
+[mypy-semver.*]
+ignore_missing_imports = True
diff --git a/src/api-service/tests/test_version_check.py b/src/api-service/tests/test_version_check.py
new file mode 100755
index 000000000..85023b23c
--- /dev/null
+++ b/src/api-service/tests/test_version_check.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+from __app__.onefuzzlib.versions import is_minimum_version
+
+
+class TestMinVersion(unittest.TestCase):
+    def test_basic(self) -> None:
+        self.assertEqual(is_minimum_version(version="1.0.0", minimum="1.0.0"), True)
+        self.assertEqual(is_minimum_version(version="2.0.0", minimum="1.0.0"), True)
+        self.assertEqual(is_minimum_version(version="2.0.0", minimum="3.0.0"), False)
+        self.assertEqual(is_minimum_version(version="1.0.0", minimum="1.6.0"), False)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py
index 2489abc3a..d3a2d1385 100644
--- a/src/pytypes/onefuzztypes/models.py
+++ b/src/pytypes/onefuzztypes/models.py
@@ -557,6 +557,10 @@ class NodeHeartbeatEntry(BaseModel):
     data: List[Dict[str, HeartbeatType]]
 
 
+class NodeCommandStopIfFree(BaseModel):
+    pass
+
+
 class StopNodeCommand(BaseModel):
     pass
 
@@ -573,6 +577,7 @@ class NodeCommand(EnumModel):
     stop: Optional[StopNodeCommand]
     stop_task: Optional[StopTaskNodeCommand]
     add_ssh_key: Optional[NodeCommandAddSshKey]
+    stop_if_free: Optional[NodeCommandStopIfFree]
 
 
 class NodeCommandEnvelope(BaseModel):