diff --git a/src/api-service/__app__/agent_can_schedule/__init__.py b/src/api-service/__app__/agent_can_schedule/__init__.py
new file mode 100644
index 000000000..3f730bd69
--- /dev/null
+++ b/src/api-service/__app__/agent_can_schedule/__init__.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import logging
+
+import azure.functions as func
+from onefuzztypes.enums import ErrorCode, TaskState
+from onefuzztypes.models import Error, NodeCommand, StopNodeCommand
+from onefuzztypes.requests import CanScheduleRequest
+from onefuzztypes.responses import CanSchedule
+
+from ..onefuzzlib.agent_authorization import verify_token
+from ..onefuzzlib.pools import Node, NodeMessage
+from ..onefuzzlib.request import not_ok, ok, parse_uri
+from ..onefuzzlib.tasks.main import Task
+
+
+def post(req: func.HttpRequest) -> func.HttpResponse:
+    """Tell a node whether it may schedule the task it is asking about.
+
+    Responds with `CanSchedule`: `allowed` is False when the node is
+    outdated (a stop command message is also saved for it), and
+    `work_stopped` is True when the task lookup returns an `Error` or the
+    task is not in the `scheduled` state.  Returns an error response when
+    the request fails to parse or the node cannot be found.
+    """
+    request = parse_uri(CanScheduleRequest, req)
+    if isinstance(request, Error):
+        return not_ok(request, context="CanScheduleRequest")
+
+    node = Node.get_by_machine_id(request.machine_id)
+    if not node:
+        return not_ok(
+            Error(code=ErrorCode.UNABLE_TO_FIND, errors=["unable to find node"]),
+            context=request.machine_id,
+        )
+
+    allowed = True
+    # `Node.is_outdated` is a method and must be called: the bare attribute
+    # is a bound method, which is always truthy.
+    if node.is_outdated():
+        logging.info(
+            "received can_schedule request from outdated node '%s' version '%s'",
+            node.machine_id,
+            node.version,
+        )
+        allowed = False
+        stop_message = NodeMessage(
+            agent_id=node.machine_id, message=NodeCommand(stop=StopNodeCommand()),
+        )
+        stop_message.save()
+
+    task = Task.get_by_task_id(request.task_id)
+
+    work_stopped = isinstance(task, Error) or (task.state != TaskState.scheduled)
+    return ok(CanSchedule(allowed=allowed, work_stopped=work_stopped))
+
+
+def main(req: func.HttpRequest) -> func.HttpResponse:
+    """Entry point: route POST requests to `post` through `verify_token`.
+
+    Raises:
+        Exception: if the HTTP method is anything other than POST.
+    """
+    if req.method == "POST":
+        m = post
+    else:
+        raise Exception("invalid method")
+
+    return verify_token(req, m)
diff --git
a/src/api-service/__app__/agent_can_schedule/function.json b/src/api-service/__app__/agent_can_schedule/function.json
new file mode 100644
index 000000000..dddca408d
--- /dev/null
+++ b/src/api-service/__app__/agent_can_schedule/function.json
@@ -0,0 +1,20 @@
+{
+  "scriptFile": "__init__.py",
+  "bindings": [
+    {
+      "authLevel": "anonymous",
+      "type": "httpTrigger",
+      "direction": "in",
+      "name": "req",
+      "methods": [
+        "post"
+      ],
+      "route": "agents/can_schedule"
+    },
+    {
+      "type": "http",
+      "direction": "out",
+      "name": "$return"
+    }
+  ]
+}
diff --git a/src/api-service/__app__/onefuzzlib/pools.py b/src/api-service/__app__/onefuzzlib/pools.py
index 9bd54e54b..758a9c3d2 100644
--- a/src/api-service/__app__/onefuzzlib/pools.py
+++ b/src/api-service/__app__/onefuzzlib/pools.py
@@ -33,6 +33,7 @@ from onefuzztypes.models import (
 from onefuzztypes.primitives import PoolName, Region
 from pydantic import Field
 
+from .__version__ import __version__
 from .azure.auth import build_auth
 from .azure.creds import get_fuzz_storage
 from .azure.image import get_os
@@ -138,6 +139,9 @@ class Node(BASE_NODE, ORMMixin):
                 node.state = NodeState.done
                 node.save()
 
+    def is_outdated(self) -> bool:  # a method: callers must invoke it, not test it
+        return self.version != __version__  # vs. onefuzzlib's own __version__
+
 
 class NodeTasks(BASE_NODE_TASK, ORMMixin):
     @classmethod
@@ -178,7 +182,7 @@ class NodeMessage(ORMMixin):
 
     @classmethod
     def key_fields(cls) -> Tuple[str, str]:
-        return ("agent_id", "create_date")
+        return ("agent_id", "message_id")
 
     @classmethod
     def get_messages(
diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py
index 145f09ad1..87d705f4b 100644
--- a/src/pytypes/onefuzztypes/models.py
+++ b/src/pytypes/onefuzztypes/models.py
@@ -398,6 +398,7 @@ class Node(BaseModel):
     state: NodeState = Field(default=NodeState.init)
     scaleset_id: Optional[UUID] = None
     tasks: Optional[List[Tuple[UUID, NodeTaskState]]] = None
+    version: str = Field(default="1.0.0")  # used when none is supplied
 
 
 class ScalesetSummary(BaseModel):
diff --git a/src/pytypes/onefuzztypes/requests.py
b/src/pytypes/onefuzztypes/requests.py
index 54a56cd8c..2eeac2ccb 100644
--- a/src/pytypes/onefuzztypes/requests.py
+++ b/src/pytypes/onefuzztypes/requests.py
@@ -197,3 +197,8 @@ class ReproGet(BaseRequest):
 
 class ProxyReset(BaseRequest):
     region: Region
+
+
+class CanScheduleRequest(BaseRequest):
+    machine_id: UUID  # agent_can_schedule.post passes it to Node.get_by_machine_id
+    task_id: UUID  # agent_can_schedule.post passes it to Task.get_by_task_id
diff --git a/src/pytypes/onefuzztypes/responses.py b/src/pytypes/onefuzztypes/responses.py
index 8f807b7c4..19acd5572 100644
--- a/src/pytypes/onefuzztypes/responses.py
+++ b/src/pytypes/onefuzztypes/responses.py
@@ -58,3 +58,8 @@ class AgentRegistration(BaseResponse):
 
 class PendingNodeCommand(BaseResponse):
     envelope: Optional[NodeCommandEnvelope]
+
+
+class CanSchedule(BaseResponse):
+    allowed: bool  # set to False by the handler for an outdated node
+    work_stopped: bool  # True if the task lookup fails or it is not `scheduled`