From e308a4ae1e96cc39e5232a810846fa9f8d66f0d5 Mon Sep 17 00:00:00 2001 From: bmc-msft <41130664+bmc-msft@users.noreply.github.com> Date: Sat, 3 Oct 2020 02:43:04 -0400 Subject: [PATCH] refactor node state to fully put the agent in charge (#90) --- .../__app__/agent_can_schedule/__init__.py | 22 +- .../__app__/agent_events/__init__.py | 84 +++-- .../__app__/agent_registration/__init__.py | 22 +- src/api-service/__app__/node/__init__.py | 9 +- .../__app__/onefuzzlib/azure/queue.py | 22 ++ .../__app__/onefuzzlib/azure/vmss.py | 2 +- src/api-service/__app__/onefuzzlib/pools.py | 330 +++++++++++------- .../__app__/onefuzzlib/tasks/main.py | 6 +- src/api-service/__app__/requirements.txt | 2 +- src/api-service/__app__/scaleset/__init__.py | 3 +- .../__app__/scaleset_events/__init__.py | 19 +- src/pytypes/onefuzztypes/models.py | 3 +- 12 files changed, 326 insertions(+), 198 deletions(-) diff --git a/src/api-service/__app__/agent_can_schedule/__init__.py b/src/api-service/__app__/agent_can_schedule/__init__.py index 6ff4700c8..8f82f8e5c 100644 --- a/src/api-service/__app__/agent_can_schedule/__init__.py +++ b/src/api-service/__app__/agent_can_schedule/__init__.py @@ -3,16 +3,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import logging - import azure.functions as func from onefuzztypes.enums import ErrorCode, TaskState -from onefuzztypes.models import Error, NodeCommand, StopNodeCommand +from onefuzztypes.models import Error from onefuzztypes.requests import CanScheduleRequest from onefuzztypes.responses import CanSchedule from ..onefuzzlib.agent_authorization import verify_token -from ..onefuzzlib.pools import Node, NodeMessage +from ..onefuzzlib.pools import Node from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.tasks.main import Task @@ -31,23 +29,13 @@ def post(req: func.HttpRequest) -> func.HttpResponse: allowed = True work_stopped = False - if node.is_outdated(): - logging.info( - "received can_schedule request from outdated node '%s' version '%s'", - node.machine_id, - node.version, - ) + + if not node.can_process_new_work(): allowed = False - stop_message = NodeMessage( - agent_id=node.machine_id, - message=NodeCommand(stop=StopNodeCommand()), - ) - stop_message.save() task = Task.get_by_task_id(request.task_id) - work_stopped = isinstance(task, Error) or (task.state != TaskState.scheduled) - + work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down() if work_stopped: allowed = False diff --git a/src/api-service/__app__/agent_events/__init__.py b/src/api-service/__app__/agent_events/__init__.py index 49aedde94..388834b09 100644 --- a/src/api-service/__app__/agent_events/__init__.py +++ b/src/api-service/__app__/agent_events/__init__.py @@ -47,11 +47,28 @@ def get_node_checked(machine_id: UUID) -> Node: def on_state_update( machine_id: UUID, state_update: NodeStateUpdate, -) -> func.HttpResponse: +) -> None: state = state_update.state node = get_node_checked(machine_id) - if state == NodeState.init or node.state not in NodeState.ready_for_reset(): + if state == NodeState.free: + if node.reimage_requested or node.delete_requested: + logging.info("stopping free node with reset flags: %s", node.machine_id) + node.stop() + return + + if node.could_shrink_scaleset(): + logging.info("stopping free node to resize scaleset: %s", node.machine_id) + node.set_halt() + return + + if state == NodeState.init: + if node.delete_requested: + node.stop() + return + node.reimage_requested = False + node.save() + elif 
node.state not in NodeState.ready_for_reset(): if node.state != state: node.state = state node.save() @@ -91,26 +108,28 @@ def on_state_update( ) node_task.save() elif state == NodeState.done: + # if tasks are running on the node when it reports as Done + # those are stopped early + node.mark_tasks_stopped_early() + # Model-validated. # # This field will be required in the future. # For now, it is optional for back compat. done_data = cast(Optional[NodeDoneEventData], state_update.data) - if done_data: + # TODO: do something with this done data if done_data.error: logging.error( - "node `done` with error: machine_id = %s, data = %s", + "node 'done' with error: machine_id:%s, data:%s", machine_id, done_data, ) else: logging.info("ignoring state updates from the node: %s: %s", machine_id, state) - return ok(BoolResult(result=True)) - -def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse: +def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None: if event.running: task_id = event.running.task_id elif event.done: @@ -129,37 +148,32 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse: task.state = TaskState.running if node.state not in NodeState.ready_for_reset(): node.state = NodeState.busy + node.save() node_task.save() # Start the clock for the task if it wasn't started already # (as happens in 1.0.0 agents) task.on_start() elif event.done: - # Only record exit status if the task isn't already shutting down. - # - # It's ok for the agent to fail because resources vanish out from underneath - # it during deletion. - if task.state not in TaskState.shutting_down(): - exit_status = event.done.exit_status - - if not exit_status.success: - logging.error("task failed: status = %s", exit_status) - - task.mark_failed( - Error( - code=ErrorCode.TASK_FAILED, - errors=[ - "task failed. exit_status = %s" % exit_status, - event.done.stdout, - event.done.stderr, - ], - ) - ) - - task.state = TaskState.stopping - if node.state not in NodeState.ready_for_reset(): - node.state = NodeState.done node_task.delete() + + exit_status = event.done.exit_status + if not exit_status.success: + logging.error("task failed. status:%s", exit_status) + task.mark_failed( + Error( + code=ErrorCode.TASK_FAILED, + errors=[ + "task failed. 
exit_status:%s" % exit_status, + event.done.stdout, + event.done.stderr, + ], + ) + ) + else: + task.mark_stopping() + + node.to_reimage(done=True) else: err = Error( code=ErrorCode.INVALID_REQUEST, @@ -168,11 +182,9 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse: raise RequestException(err) task.save() - node.save() task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event) task_event.save() - return ok(BoolResult(result=True)) def post(req: func.HttpRequest) -> func.HttpResponse: @@ -197,9 +209,11 @@ def post(req: func.HttpRequest) -> func.HttpResponse: return not_ok(err, context=ERROR_CONTEXT) if event.state_update: - return on_state_update(envelope.machine_id, event.state_update) + on_state_update(envelope.machine_id, event.state_update) + return ok(BoolResult(result=True)) elif event.worker_event: - return on_worker_event(envelope.machine_id, event.worker_event) + on_worker_event(envelope.machine_id, event.worker_event) + return ok(BoolResult(result=True)) else: err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"]) return not_ok(err, context=ERROR_CONTEXT) diff --git a/src/api-service/__app__/agent_registration/__init__.py b/src/api-service/__app__/agent_registration/__init__.py index 41d13ccde..4f42d7c0f 100644 --- a/src/api-service/__app__/agent_registration/__init__.py +++ b/src/api-service/__app__/agent_registration/__init__.py @@ -6,7 +6,7 @@ from uuid import UUID import azure.functions as func -from onefuzztypes.enums import ErrorCode +from onefuzztypes.enums import ErrorCode, NodeState from onefuzztypes.models import Error from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost from onefuzztypes.responses import AgentRegistration @@ -78,7 +78,6 @@ def post(req: func.HttpRequest) -> func.HttpResponse: registration_request = parse_uri(AgentRegistrationPost, req) if isinstance(registration_request, Error): return not_ok(registration_request, context="agent registration") - agent_node = Node.get_by_machine_id(registration_request.machine_id) pool = Pool.get_by_name(registration_request.pool_name) if isinstance(pool, Error): @@ -90,20 +89,23 @@ def post(req: func.HttpRequest) -> func.HttpResponse: context="agent registration", ) - if agent_node is None: - agent_node = Node( + node = Node.get_by_machine_id(registration_request.machine_id) + if node: + if node.version != registration_request.version: + NodeMessage.clear_messages(node.machine_id) + node.version = registration_request.version + node.reimage_requested = False + node.state = NodeState.init + else: + node = Node( pool_name=registration_request.pool_name, machine_id=registration_request.machine_id, scaleset_id=registration_request.scaleset_id, version=registration_request.version, ) - agent_node.save() - elif agent_node.version.lower != registration_request.version: - NodeMessage.clear_messages(agent_node.machine_id) - agent_node.version = registration_request.version - agent_node.save() + node.save() - return create_registration_response(agent_node.machine_id, pool) + return create_registration_response(node.machine_id, pool) def main(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/node/__init__.py b/src/api-service/__app__/node/__init__.py index a33972abd..b6dd64476 100644 --- a/src/api-service/__app__/node/__init__.py +++ b/src/api-service/__app__/node/__init__.py @@ -4,7 +4,7 @@ # Licensed under the MIT License. 
import azure.functions as func -from onefuzztypes.enums import ErrorCode, NodeState +from onefuzztypes.enums import ErrorCode from onefuzztypes.models import Error from onefuzztypes.requests import NodeGet, NodeSearch from onefuzztypes.responses import BoolResult @@ -54,8 +54,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse: context=request.machine_id, ) - node.state = NodeState.halt - node.save() + node.set_halt() return ok(BoolResult(result=True)) @@ -72,9 +71,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse: context=request.machine_id, ) - node.state = NodeState.done - node.save() - + node.stop() return ok(BoolResult(result=True)) diff --git a/src/api-service/__app__/onefuzzlib/azure/queue.py b/src/api-service/__app__/onefuzzlib/azure/queue.py index bb7aa71f2..e1e1c7d97 100644 --- a/src/api-service/__app__/onefuzzlib/azure/queue.py +++ b/src/api-service/__app__/onefuzzlib/azure/queue.py @@ -88,6 +88,15 @@ def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceC return None +def clear_queue(name: QueueNameType, *, account_id: str) -> None: + queue = get_queue(name, account_id=account_id) + if queue: + try: + queue.clear_messages() + except ResourceNotFoundError: + return None + + def send_message( name: QueueNameType, message: bytes, @@ -102,6 +111,19 @@ def send_message( pass +def remove_first_message(name: QueueNameType, *, account_id: str) -> bool: + create_queue(name, account_id=account_id) + queue = get_queue(name, account_id=account_id) + if queue: + try: + for message in queue.receive_messages(): + queue.delete_message(message) + return True + except ResourceNotFoundError: + return False + return False + + A = TypeVar("A", bound=BaseModel) diff --git a/src/api-service/__app__/onefuzzlib/azure/vmss.py b/src/api-service/__app__/onefuzzlib/azure/vmss.py index d073e9078..7fce02ec9 100644 --- a/src/api-service/__app__/onefuzzlib/azure/vmss.py +++ b/src/api-service/__app__/onefuzzlib/azure/vmss.py @@ -76,7 +76,7 @@ def get_vmss_size(name: UUID) -> Optional[int]: def list_instance_ids(name: UUID) -> Dict[UUID, str]: - logging.info("get instance IDs for scaleset: %s", name) + logging.debug("get instance IDs for scaleset: %s", name) resource_group = get_base_resource_group() compute_client = mgmt_client_factory(ComputeManagementClient) diff --git a/src/api-service/__app__/onefuzzlib/pools.py b/src/api-service/__app__/onefuzzlib/pools.py index 4d103efc9..035adeaf8 100644 --- a/src/api-service/__app__/onefuzzlib/pools.py +++ b/src/api-service/__app__/onefuzzlib/pools.py @@ -6,7 +6,7 @@ import datetime import logging from typing import Dict, List, Optional, Tuple, Union -from uuid import UUID +from uuid import UUID, uuid4 from onefuzztypes.enums import ( OS, @@ -15,7 +15,6 @@ from onefuzztypes.enums import ( NodeState, PoolState, ScalesetState, - TaskState, ) from onefuzztypes.models import Error from onefuzztypes.models import Node as BASE_NODE @@ -32,14 +31,21 @@ from onefuzztypes.models import ( WorkUnitSummary, ) from onefuzztypes.primitives import PoolName, Region -from pydantic import Field +from pydantic import BaseModel, Field from .__version__ import __version__ from .azure.auth import build_auth -from .azure.creds import get_fuzz_storage +from .azure.creds import get_func_storage, get_fuzz_storage from .azure.image import get_os from .azure.network import Network -from .azure.queue import create_queue, delete_queue, peek_queue, queue_object +from .azure.queue import ( + clear_queue, + create_queue, + delete_queue, + peek_queue, + 
queue_object, + remove_first_message, +) from .azure.vmss import ( UnableToUpdate, create_vmss, @@ -99,7 +105,7 @@ class Node(BASE_NODE, ORMMixin): # azure table query always return false when the column does not exist # We write the query this way to allow us to get the nodes where the # version is not defined as well as the nodes with a mismatched version - version_query = "not (version ne '%s')" % __version__ + version_query = "not (version eq '%s')" % __version__ return cls.search(query=query, raw_unchecked_filter=version_query) @classmethod @@ -154,16 +160,90 @@ class Node(BASE_NODE, ORMMixin): for node in nodes: if node.state not in NodeState.ready_for_reset(): logging.info( - "stopping task %s on machine_id:%s", - task_id, + "stopping machine_id:%s running task:%s", node.machine_id, + task_id, ) - node.state = NodeState.done - node.save() + node.stop() + + def mark_tasks_stopped_early(self) -> None: + from .tasks.main import Task + + for entry in NodeTasks.get_by_machine_id(self.machine_id): + task = Task.get_by_task_id(entry.task_id) + if isinstance(task, Task): + task.mark_failed( + Error( + code=ErrorCode.TASK_FAILED, + errors=["node reimaged during task execution"], + ) + ) + entry.delete() + + def could_shrink_scaleset(self) -> bool: + if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink(): + return True + return False + + def can_process_new_work(self) -> bool: + if self.is_outdated(): + logging.info( + "can_schedule old version machine_id:%s version:%s", + self.machine_id, + self.version, + ) + self.stop() + return False + + if self.delete_requested or self.reimage_requested: + logging.info( + "can_schedule should be recycled. machine_id:%s", self.machine_id + ) + self.stop() + return False + + if self.could_shrink_scaleset(): + self.set_halt() + logging.info("node scheduled to shrink. machine_id:%s", self.machine_id) + return False + + return True def is_outdated(self) -> bool: return self.version != __version__ + def send_message(self, message: NodeCommand) -> None: + stop_message = NodeMessage( + agent_id=self.machine_id, + message=message, + ) + stop_message.save() + + def to_reimage(self, done: bool = False) -> None: + if done: + if self.state not in NodeState.ready_for_reset(): + self.state = NodeState.done + + if not self.reimage_requested and not self.delete_requested: + logging.info("setting reimage_requested: %s", self.machine_id) + self.reimage_requested = True + self.save() + + def stop(self) -> None: + self.to_reimage() + self.send_message(NodeCommand(stop=StopNodeCommand())) + + def set_shutdown(self) -> None: + # don't give out more work to the node, but let it finish existing work + logging.info("setting delete_requested: %s", self.machine_id) + self.delete_requested = True + self.save() + + def set_halt(self) -> None: + """ Tell the node to stop everything. 
""" + self.set_shutdown() + self.stop() + class NodeTasks(BASE_NODE_TASK, ORMMixin): @classmethod @@ -384,8 +464,7 @@ class Pool(BASE_POOL, ORMMixin): scaleset.save() for node in nodes: - node.state = NodeState.shutdown - node.save() + node.set_shutdown() self.save() @@ -405,13 +484,7 @@ class Pool(BASE_POOL, ORMMixin): scaleset.save() for node in nodes: - logging.info( - "deleting node from pool: %s (%s) - machine_id:%s", - self.pool_id, - self.name, - node.machine_id, - ) - node.delete() + node.set_halt() self.save() @@ -496,6 +569,8 @@ class Scaleset(BASE_SCALESET, ORMMixin): def init(self) -> None: logging.info("scaleset init: %s", self.scaleset_id) + ScalesetShrinkQueue(self.scaleset_id).create() + # Handle the race condition between a pool being deleted and a # scaleset being added to the pool. pool = Pool.get_by_name(self.pool_name) @@ -596,49 +671,62 @@ class Scaleset(BASE_SCALESET, ORMMixin): # result = 'did I modify the scaleset in azure' def cleanup_nodes(self) -> bool: if self.state == ScalesetState.halt: + logging.info("halting scaleset: %s", self.scaleset_id) self.halt() return True + to_reimage = [] + to_delete = [] + + outdated = Node.search_outdated(scaleset_id=self.scaleset_id) + for node in outdated: + logging.info( + "node is outdated: %s - node_version:%s api_version:%s", + node.machine_id, + node.version, + __version__, + ) + if node.version == "1.0.0": + node.state = NodeState.done + to_reimage.append(node) + else: + node.to_reimage() + nodes = Node.search_states( scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset() ) - outdated = Node.search_outdated( - scaleset_id=self.scaleset_id, - states=[NodeState.free], - ) - - if not (nodes or outdated): - logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id) + if not outdated and not nodes: + logging.info("no nodes need updating: %s", self.scaleset_id) return False - to_delete = [] - to_reimage = [] - - for node in outdated: - if node.version == "1.0.0": - to_reimage.append(node) - else: - stop_message = NodeMessage( - agent_id=node.machine_id, - message=NodeCommand(stop=StopNodeCommand()), - ) - stop_message.save() + # ground truth of existing nodes + azure_nodes = list_instance_ids(self.scaleset_id) for node in nodes: - # delete nodes that are not waiting on the scaleset GC - if not node.scaleset_node_exists(): + if node.machine_id not in azure_nodes: + logging.info( + "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id + ) node.delete() - elif node.state in [NodeState.shutdown, NodeState.halt]: + elif node.delete_requested: to_delete.append(node) else: - to_reimage.append(node) + if ScalesetShrinkQueue(self.scaleset_id).should_shrink(): + node.set_halt() + to_delete.append(node) + else: + to_reimage.append(node) # Perform operations until they fail due to scaleset getting locked try: if to_delete: + logging.info( + "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete) + ) self.delete_nodes(to_delete) for node in to_delete: + node.set_halt() node.state = NodeState.halt node.save() @@ -646,69 +734,65 @@ class Scaleset(BASE_SCALESET, ORMMixin): self.reimage_nodes(to_reimage) except UnableToUpdate: logging.info("scaleset update already in progress: %s", self.scaleset_id) + return True - def resize(self) -> None: - logging.info( - "scaleset resize: %s - current: %s new: %s", - self.scaleset_id, - self.size, - self.new_size, - ) - - # no work needed to resize - if self.new_size is None: + def _resize_equal(self) -> None: + # NOTE: this is the only place we reset to the 
'running' state. + # This ensures that our idea of scaleset size agrees with Azure + node_count = len(Node.search_states(scaleset_id=self.scaleset_id)) + if node_count == self.size: + logging.info("resize finished: %s", self.scaleset_id) self.state = ScalesetState.running self.save() return + else: + logging.info( + "resize is finished, waiting for nodes to check in: " + "%s (%d of %d nodes checked in)", + self.scaleset_id, + node_count, + self.size, + ) + return + + def _resize_grow(self) -> None: + try: + resize_vmss(self.scaleset_id, self.size) + except UnableToUpdate: + logging.info("scaleset is mid-operation already") + return + + def _resize_shrink(self, to_remove: int) -> None: + queue = ScalesetShrinkQueue(self.scaleset_id) + for _ in range(to_remove): + queue.add_entry() + + def resize(self) -> None: + # no longer needing to resize + if self.state != ScalesetState.resize: + return + + logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size) + + # reset the node delete queue + ScalesetShrinkQueue(self.scaleset_id).clear() # just in case, always ensure size is within max capacity - self.new_size = min(self.new_size, self.max_size()) + self.size = min(self.size, self.max_size()) # Treat Azure knowledge of the size of the scaleset as "ground truth" size = get_vmss_size(self.scaleset_id) if size is None: - logging.info("scaleset is unavailable. Re-queuing") - self.save() + logging.info("scaleset is unavailable: %s", self.scaleset_id) return - if size == self.new_size: - # NOTE: this is the only place we reset to the 'running' state. - # This ensures that our idea of scaleset size agrees with Azure - node_count = len(Node.search_states(scaleset_id=self.scaleset_id)) - if node_count == self.size: - logging.info("resize finished: %s", self.scaleset_id) - self.new_size = None - self.state = ScalesetState.running - else: - logging.info( - "resize is finished, waiting for nodes to check in: " - "%s (%d of %d nodes checked in)", - self.scaleset_id, - node_count, - self.size, - ) - # When adding capacity, call the resize API directly - elif self.new_size > self.size: - try: - resize_vmss(self.scaleset_id, self.new_size) - except UnableToUpdate: - logging.info("scaleset is mid-operation already") - # Shut down any nodes without work. 
Otherwise, rely on Scaleset.reimage_node - # to pick up that the scaleset is too big upon task completion + if size == self.size: + self._resize_equal() + elif self.size > size: + self._resize_grow() else: - nodes = Node.search_states( - scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free] - ) - for node in nodes: - if size > self.new_size: - node.state = NodeState.halt - node.save() - size -= 1 - else: - break - - self.save() + self._resize_shrink(size - self.size) def delete_nodes(self, nodes: List[Node]) -> None: if not nodes: @@ -723,31 +807,12 @@ class Scaleset(BASE_SCALESET, ORMMixin): logging.info("deleting %s:%s", self.scaleset_id, machine_ids) delete_vmss_nodes(self.scaleset_id, machine_ids) - self.size -= len(machine_ids) - self.save() def reimage_nodes(self, nodes: List[Node]) -> None: - from .tasks.main import Task - if not nodes: logging.debug("no nodes to reimage") return - for node in nodes: - for entry in NodeTasks.get_by_machine_id(node.machine_id): - task = Task.get_by_task_id(entry.task_id) - if isinstance(task, Task): - if task.state in [TaskState.stopping, TaskState.stopped]: - continue - - task.mark_failed( - Error( - code=ErrorCode.TASK_FAILED, - errors=["node reimaged during task execution"], - ) - ) - entry.delete() - if self.state == ScalesetState.shutdown: self.delete_nodes(nodes) return @@ -766,28 +831,28 @@ class Scaleset(BASE_SCALESET, ORMMixin): ) def shutdown(self) -> None: - logging.info("scaleset shutdown: %s", self.scaleset_id) size = get_vmss_size(self.scaleset_id) + logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size) if size is None or size == 0: - self.state = ScalesetState.halt self.halt() - return - self.save() def halt(self) -> None: + self.state = ScalesetState.halt + ScalesetShrinkQueue(self.scaleset_id).delete() + for node in Node.search_states(scaleset_id=self.scaleset_id): logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id) node.delete() vmss = get_vmss(self.scaleset_id) - if vmss is None: - logging.info("scaleset deleted: %s", self.scaleset_id) - self.state = ScalesetState.halt - self.delete() - else: + if vmss: logging.info("scaleset deleting: %s", self.scaleset_id) delete_vmss(self.scaleset_id) self.save() + else: + logging.info("scaleset deleted: %s", self.scaleset_id) + self.state = ScalesetState.halt + self.delete() def max_size(self) -> int: # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/ @@ -858,3 +923,30 @@ class Scaleset(BASE_SCALESET, ORMMixin): @classmethod def key_fields(cls) -> Tuple[str, str]: return ("pool_name", "scaleset_id") + + +class ShrinkEntry(BaseModel): + shrink_id: UUID = Field(default_factory=uuid4) + + +class ScalesetShrinkQueue: + def __init__(self, scaleset_id: UUID): + self.scaleset_id = scaleset_id + + def queue_name(self) -> str: + return "to-shrink-%s" % self.scaleset_id.hex + + def clear(self) -> None: + clear_queue(self.queue_name(), account_id=get_func_storage()) + + def create(self) -> None: + create_queue(self.queue_name(), account_id=get_func_storage()) + + def delete(self) -> None: + delete_queue(self.queue_name(), account_id=get_func_storage()) + + def add_entry(self) -> None: + queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage()) + + def should_shrink(self) -> bool: + return remove_first_message(self.queue_name(), account_id=get_func_storage()) diff --git a/src/api-service/__app__/onefuzzlib/tasks/main.py b/src/api-service/__app__/onefuzzlib/tasks/main.py index eb299dafa..37c394fe7 100644 --- 
a/src/api-service/__app__/onefuzzlib/tasks/main.py +++ b/src/api-service/__app__/onefuzzlib/tasks/main.py @@ -108,7 +108,6 @@ class Task(BASE_TASK, ORMMixin): self.save() def stopping(self) -> None: - # TODO: we need to tell every node currently working on this task to stop # TODO: we need to 'unschedule' this task from the existing pools self.state = TaskState.stopping @@ -154,6 +153,11 @@ class Task(BASE_TASK, ORMMixin): task = tasks[0] return task + def mark_stopping(self) -> None: + if self.state not in [TaskState.stopped, TaskState.stopping]: + self.state = TaskState.stopping + self.save() + def mark_failed(self, error: Error) -> None: if self.state in [TaskState.stopped, TaskState.stopping]: logging.debug( diff --git a/src/api-service/__app__/requirements.txt b/src/api-service/__app__/requirements.txt index dc898d35f..108c5d2ea 100644 --- a/src/api-service/__app__/requirements.txt +++ b/src/api-service/__app__/requirements.txt @@ -31,5 +31,5 @@ pydantic~=1.6.1 PyJWT~=1.7.1 requests~=2.24.0 memoization~=0.3.1 -# onefuzztypes version is set during build +# onefuzz types version is set during build onefuzztypes==0.0.0 diff --git a/src/api-service/__app__/scaleset/__init__.py b/src/api-service/__app__/scaleset/__init__.py index ea9fd19e4..c727652cd 100644 --- a/src/api-service/__app__/scaleset/__init__.py +++ b/src/api-service/__app__/scaleset/__init__.py @@ -109,6 +109,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse: scaleset.state = ScalesetState.halt else: scaleset.state = ScalesetState.shutdown + scaleset.save() scaleset.auth = None return ok(scaleset) @@ -133,7 +134,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse: ) if request.size is not None: - scaleset.new_size = request.size + scaleset.size = request.size scaleset.state = ScalesetState.resize scaleset.save() diff --git a/src/api-service/__app__/scaleset_events/__init__.py b/src/api-service/__app__/scaleset_events/__init__.py index 6b69abf6d..0513de4a3 100644 --- a/src/api-service/__app__/scaleset_events/__init__.py +++ b/src/api-service/__app__/scaleset_events/__init__.py @@ -13,21 +13,28 @@ from ..onefuzzlib.pools import Scaleset def process_scaleset(scaleset: Scaleset) -> None: - if scaleset.state == ScalesetState.halt: - scaleset.halt() - return + logging.debug("checking scaleset for updates: %s", scaleset.scaleset_id) + + if scaleset.state == ScalesetState.resize: + scaleset.resize() # if the scaleset is touched during cleanup, don't continue to process it if scaleset.cleanup_nodes(): + logging.debug("scaleset needed cleanup: %s", scaleset.scaleset_id) return - if scaleset.state in ScalesetState.needs_work(): + if ( + scaleset.state in ScalesetState.needs_work() + and scaleset.state != ScalesetState.resize + ): logging.info( "exec scaleset state: %s - %s", scaleset.scaleset_id, - scaleset.state.name, + scaleset.state, ) - getattr(scaleset, scaleset.state.name)() + + if hasattr(scaleset, scaleset.state.name): + getattr(scaleset, scaleset.state.name)() return diff --git a/src/pytypes/onefuzztypes/models.py b/src/pytypes/onefuzztypes/models.py index 55192dc18..42b372f19 100644 --- a/src/pytypes/onefuzztypes/models.py +++ b/src/pytypes/onefuzztypes/models.py @@ -399,6 +399,8 @@ class Node(BaseModel): scaleset_id: Optional[UUID] = None tasks: Optional[List[Tuple[UUID, NodeTaskState]]] = None version: str = Field(default="1.0.0") + reimage_requested: bool = Field(default=False) + delete_requested: bool = Field(default=False) class ScalesetSummary(BaseModel): @@ -447,7 +449,6 @@ class Scaleset(BaseModel): 
     image: str
     region: Region
     size: int
-    new_size: Optional[int]
     spot_instances: bool
     error: Optional[Error]
     nodes: Optional[List[ScalesetNodeState]]
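
Scheduling decisions now go through Node.can_process_new_work(), which turns down work for an outdated agent, for a node already flagged with reimage_requested or delete_requested, and for a node whose scaleset has shrink entries queued; the agent_can_schedule endpoint simply asks the node. A minimal, self-contained sketch of that gate (NodeSketch, CURRENT_VERSION, and the scaleset_wants_to_shrink flag are simplified stand-ins for the onefuzzlib types, not the real model):

from dataclasses import dataclass, field
from uuid import UUID, uuid4

CURRENT_VERSION = "1.1.0"  # stand-in for the service __version__


@dataclass
class NodeSketch:
    # Illustrative stand-in for the Node model; only the fields the
    # scheduling gate looks at are included.
    machine_id: UUID = field(default_factory=uuid4)
    version: str = CURRENT_VERSION
    reimage_requested: bool = False
    delete_requested: bool = False
    scaleset_wants_to_shrink: bool = False  # stands in for could_shrink_scaleset()

    def can_process_new_work(self) -> bool:
        # Outdated agents never receive new work; they are told to stop instead.
        if self.version != CURRENT_VERSION:
            return False
        # Nodes already flagged for delete or reimage should drain, not schedule.
        if self.delete_requested or self.reimage_requested:
            return False
        # If the scaleset has pending shrink requests, a free node gives itself up.
        if self.scaleset_wants_to_shrink:
            return False
        return True


if __name__ == "__main__":
    node = NodeSketch()
    assert node.can_process_new_work()
    node.reimage_requested = True
    assert not node.can_process_new_work()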
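
The heart of the refactor is that reimaging and deletion are expressed as intent flags the agent acts on: on_state_update() stops a free node whose reimage_requested or delete_requested flag is set, halts a free node when its scaleset has shrink requests pending, and clears reimage_requested once a node reports init again. A simplified sketch of that dispatch (the state enum and string results below are illustrative stand-ins, not the real handler):

from enum import Enum


class NodeStateSketch(Enum):
    init = "init"
    free = "free"
    busy = "busy"
    done = "done"


def on_state_report(state: NodeStateSketch, reimage_requested: bool,
                    delete_requested: bool, should_shrink: bool) -> str:
    # Sketch of the new free/init handling in on_state_update(): the service
    # records intent and waits for the agent to report 'free' or 'init'
    # before telling it to stop or halt.
    if state is NodeStateSketch.free:
        if reimage_requested or delete_requested:
            return "stop"            # node.stop(): queue a StopNodeCommand
        if should_shrink:
            return "halt"            # node.set_halt(): the scaleset is shrinking
        return "keep"
    if state is NodeStateSketch.init:
        if delete_requested:
            return "stop"
        return "clear-reimage-flag"  # the patch resets reimage_requested on init
    return "record-state"            # other states are recorded as before


if __name__ == "__main__":
    assert on_state_report(NodeStateSketch.free, True, False, False) == "stop"
    assert on_state_report(NodeStateSketch.free, False, False, True) == "halt"
    assert on_state_report(NodeStateSketch.init, False, False, False) == "clear-reimage-flag"
    assert on_state_report(NodeStateSketch.busy, False, False, False) == "record-state"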
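
When a worker reports done, the handler deletes the NodeTasks entry, records a failure via Task.mark_failed() if the exit status was bad or calls Task.mark_stopping() otherwise, then marks the node for reimage. Both task methods guard against touching a task that is already stopping or stopped, so a late report from a dying node cannot resurrect a finished task. A sketch of that guard under a cut-down state enum (the exact transition targets here are assumptions):

from enum import Enum
from typing import List, Optional


class TaskStateSketch(Enum):
    running = "running"
    stopping = "stopping"
    stopped = "stopped"


class TaskSketch:
    def __init__(self) -> None:
        self.state = TaskStateSketch.running
        self.errors: Optional[List[str]] = None

    def mark_stopping(self) -> None:
        # Mirrors Task.mark_stopping(): never regress a task that is shutting down.
        if self.state not in (TaskStateSketch.stopping, TaskStateSketch.stopped):
            self.state = TaskStateSketch.stopping

    def mark_failed(self, errors: List[str]) -> None:
        # Mirrors the guard in Task.mark_failed(): failure reports that arrive
        # after shutdown began are ignored rather than applied.
        if self.state in (TaskStateSketch.stopping, TaskStateSketch.stopped):
            return
        self.errors = errors
        self.state = TaskStateSketch.stopping


if __name__ == "__main__":
    task = TaskSketch()
    task.mark_stopping()
    task.mark_failed(["exit_status:1"])  # ignored: shutdown already started
    assert task.state is TaskStateSketch.stopping and task.errors is None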
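
Scaleset shrinking is queue-driven: ScalesetShrinkQueue keeps one entry per node that should be removed on a storage queue named "to-shrink-<scaleset id>", _resize_shrink() enqueues entries, and a free node that successfully pops one (should_shrink(), backed by remove_first_message()) halts itself. An in-memory stand-in that reproduces those semantics, using queue.Queue instead of the Azure Storage helpers:

import queue
from uuid import UUID, uuid4


class ShrinkQueueSketch:
    # In-memory stand-in for ScalesetShrinkQueue; the real helper uses an
    # Azure Storage queue named "to-shrink-<scaleset id>".
    def __init__(self, scaleset_id: UUID) -> None:
        self.scaleset_id = scaleset_id
        self._entries: "queue.Queue[UUID]" = queue.Queue()

    def add_entry(self) -> None:
        # _resize_shrink() adds one entry per node that must be removed.
        self._entries.put(uuid4())

    def should_shrink(self) -> bool:
        # A free node pops one entry; success means "this node should halt".
        try:
            self._entries.get_nowait()
            return True
        except queue.Empty:
            return False


if __name__ == "__main__":
    shrink = ShrinkQueueSketch(uuid4())
    for _ in range(2):          # ask for two nodes to go away
        shrink.add_entry()
    results = [shrink.should_shrink() for _ in range(4)]
    print(results)              # [True, True, False, False]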
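
With new_size gone from the Scaleset model, resize() clamps Scaleset.size to max_size(), treats the VMSS size reported by Azure as ground truth, and dispatches to _resize_equal, _resize_grow, or _resize_shrink. A small sketch of that decision, using a hypothetical plan_resize() helper that returns labels instead of calling the real methods:

from typing import Optional


def plan_resize(desired: int, reported: Optional[int], max_size: int) -> str:
    # Sketch of the dispatch in Scaleset.resize(); the string results are
    # illustrative labels, not the real method calls.
    desired = min(desired, max_size)           # always stay within VMSS capacity
    if reported is None:
        return "retry"                         # scaleset unavailable right now
    if reported == desired:
        return "wait-for-nodes"                # _resize_equal: back to 'running' once nodes check in
    if desired > reported:
        return "grow"                          # _resize_grow: call the VMSS resize API
    return "shrink:%d" % (reported - desired)  # _resize_shrink: queue that many shrink entries


if __name__ == "__main__":
    assert plan_resize(10, 10, 1000) == "wait-for-nodes"
    assert plan_resize(20, 10, 1000) == "grow"
    assert plan_resize(5, 10, 1000) == "shrink:5"
    assert plan_resize(5000, 10, 1000) == "grow"  # clamped to max_size before comparing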