diff --git a/src/agent/onefuzz-supervisor/src/config.rs b/src/agent/onefuzz-supervisor/src/config.rs index dbba31c6d..72bc968c4 100644 --- a/src/agent/onefuzz-supervisor/src/config.rs +++ b/src/agent/onefuzz-supervisor/src/config.rs @@ -115,7 +115,9 @@ impl Registration { if managed { let scaleset = onefuzz::machine_id::get_scaleset_name().await?; - url.query_pairs_mut().append_pair("scaleset_id", &scaleset); + url.query_pairs_mut() + .append_pair("scaleset_id", &scaleset) + .append_pair("version", env!("ONEFUZZ_VERSION")); } // The registration can fail because this call is made before the virtual machine scaleset is done provisioning // The authentication layer of the service will reject this request when that happens diff --git a/src/api-service/__app__/agent_registration/__init__.py b/src/api-service/__app__/agent_registration/__init__.py index 7f9d4a69f..1f244f6f9 100644 --- a/src/api-service/__app__/agent_registration/__init__.py +++ b/src/api-service/__app__/agent_registration/__init__.py @@ -14,7 +14,7 @@ from onefuzztypes.responses import AgentRegistration from ..onefuzzlib.agent_authorization import verify_token from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_name from ..onefuzzlib.azure.queue import get_queue_sas -from ..onefuzzlib.pools import Node, Pool +from ..onefuzzlib.pools import Node, NodeMessage, Pool from ..onefuzzlib.request import not_ok, ok, parse_uri @@ -44,7 +44,6 @@ def get(req: func.HttpRequest) -> func.HttpResponse: if isinstance(get_registration, Error): return not_ok(get_registration, context="agent registration") - # check if an existone registration exists agent_node = Node.get_by_machine_id(get_registration.machine_id) if agent_node is None: @@ -79,7 +78,6 @@ def post(req: func.HttpRequest) -> func.HttpResponse: registration_request = parse_uri(AgentRegistrationPost, req) if isinstance(registration_request, Error): return not_ok(registration_request, context="agent registration") - # check if an existone registration exists agent_node = Node.get_by_machine_id(registration_request.machine_id) pool = Pool.get_by_name(registration_request.pool_name) @@ -97,8 +95,13 @@ def post(req: func.HttpRequest) -> func.HttpResponse: pool_name=registration_request.pool_name, machine_id=registration_request.machine_id, scaleset_id=registration_request.scaleset_id, + version=registration_request.version ) agent_node.save() + elif agent_node.version.lower != registration_request.version: + NodeMessage.clear_messages(agent_node.machine_id) + agent_node.version = registration_request.version + agent_node.save() return create_registration_response(agent_node.machine_id, pool) diff --git a/src/api-service/__app__/onefuzzlib/pools.py b/src/api-service/__app__/onefuzzlib/pools.py index 758a9c3d2..6d1882099 100644 --- a/src/api-service/__app__/onefuzzlib/pools.py +++ b/src/api-service/__app__/onefuzzlib/pools.py @@ -26,6 +26,7 @@ from onefuzztypes.models import Scaleset as BASE_SCALESET from onefuzztypes.models import ( ScalesetNodeState, ScalesetSummary, + StopNodeCommand, WorkSet, WorkSetSummary, WorkUnitSummary, @@ -80,6 +81,28 @@ class Node(BASE_NODE, ORMMixin): query["pool_name"] = [pool_name] return cls.search(query=query) + @classmethod + def search_outdated( + cls, + *, + scaleset_id: Optional[UUID] = None, + states: Optional[List[NodeState]] = None, + pool_name: Optional[str] = None, + ) -> List["Node"]: + query: QueryFilter = {} + if scaleset_id: + query["scaleset_id"] = [scaleset_id] + if states: + query["state"] = states + if pool_name: + query["pool_name"] = [pool_name] + + # azure table query always return false when the column does not exist + # We write the query this way to allow us to get the nodes where the + # version is not defined as well as the nodes with a mismatched version + version_query = "not (version ne '%s')" % __version__ + return cls.search(query=query, raw_unchecked_filter=version_query) + @classmethod def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]: nodes = cls.search(query={"machine_id": [machine_id]}) @@ -132,9 +155,7 @@ class Node(BASE_NODE, ORMMixin): for node in nodes: if node.state not in NodeState.ready_for_reset(): logging.info( - "stopping task %s on machine_id:%s", - task_id, - node.machine_id, + "stopping task %s on machine_id:%s", task_id, node.machine_id, ) node.state = NodeState.done node.save() @@ -203,6 +224,12 @@ class NodeMessage(ORMMixin): client.commit_batch(cls.table_name(), batch) + @classmethod + def clear_messages(cls, agent_id: UUID) -> None: + messages = cls.get_messages(agent_id) + message_ids = [m.message_id for m in messages] + cls.delete_messages(agent_id, message_ids) + class Pool(BASE_POOL, ORMMixin): @classmethod @@ -569,13 +596,29 @@ class Scaleset(BASE_SCALESET, ORMMixin): nodes = Node.search_states( scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset() ) - if not nodes: + + outdated = Node.search_outdated( + scaleset_id=self.scaleset_id, + states=[NodeState.free], + ) + + if not (nodes or outdated): logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id) return False to_delete = [] to_reimage = [] + for node in outdated: + if node.version == "1.0.0": + to_reimage.append(node) + else: + stop_message = NodeMessage( + agent_id=node.machine_id, + message=NodeCommand(stop=StopNodeCommand()), + ) + stop_message.save() + for node in nodes: # delete nodes that are not waiting on the scaleset GC if not node.scaleset_node_exists(): @@ -779,8 +822,7 @@ class Scaleset(BASE_SCALESET, ORMMixin): break if not node_state: node_state = ScalesetNodeState( - machine_id=machine_id, - instance_id=instance_id, + machine_id=machine_id, instance_id=instance_id, ) self.nodes.append(node_state) diff --git a/src/pytypes/onefuzztypes/requests.py b/src/pytypes/onefuzztypes/requests.py index 2eeac2ccb..6de85f217 100644 --- a/src/pytypes/onefuzztypes/requests.py +++ b/src/pytypes/onefuzztypes/requests.py @@ -82,6 +82,7 @@ class AgentRegistrationPost(BaseRequest): pool_name: PoolName scaleset_id: Optional[UUID] machine_id: UUID + version: str class PoolCreate(BaseRequest):