diff --git a/src/api-service/__app__/agent_can_schedule/__init__.py b/src/api-service/__app__/agent_can_schedule/__init__.py index 36e12d3f7..d3aba1530 100644 --- a/src/api-service/__app__/agent_can_schedule/__init__.py +++ b/src/api-service/__app__/agent_can_schedule/__init__.py @@ -11,9 +11,9 @@ from onefuzztypes.responses import CanSchedule from ..onefuzzlib.endpoint_authorization import call_if_agent from ..onefuzzlib.events import get_events -from ..onefuzzlib.pools import Node from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.tasks.main import Task +from ..onefuzzlib.workers.nodes import Node def post(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/agent_commands/__init__.py b/src/api-service/__app__/agent_commands/__init__.py index 89af2de27..b0e3fe7dd 100644 --- a/src/api-service/__app__/agent_commands/__init__.py +++ b/src/api-service/__app__/agent_commands/__init__.py @@ -10,8 +10,8 @@ from onefuzztypes.responses import BoolResult, PendingNodeCommand from ..onefuzzlib.endpoint_authorization import call_if_agent from ..onefuzzlib.events import get_events -from ..onefuzzlib.pools import NodeMessage from ..onefuzzlib.request import not_ok, ok, parse_request +from ..onefuzzlib.workers.nodes import NodeMessage def get(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/agent_registration/__init__.py b/src/api-service/__app__/agent_registration/__init__.py index 46b3eaf60..ed4803c99 100644 --- a/src/api-service/__app__/agent_registration/__init__.py +++ b/src/api-service/__app__/agent_registration/__init__.py @@ -18,8 +18,9 @@ from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.azure.storage import StorageType from ..onefuzzlib.endpoint_authorization import call_if_agent from ..onefuzzlib.events import get_events -from ..onefuzzlib.pools import Node, Pool from ..onefuzzlib.request import not_ok, ok, parse_uri +from ..onefuzzlib.workers.nodes import Node +from ..onefuzzlib.workers.pools import Pool def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpResponse: diff --git a/src/api-service/__app__/node/__init__.py b/src/api-service/__app__/node/__init__.py index e6beef9e8..c58c92903 100644 --- a/src/api-service/__app__/node/__init__.py +++ b/src/api-service/__app__/node/__init__.py @@ -11,8 +11,8 @@ from onefuzztypes.responses import BoolResult from ..onefuzzlib.endpoint_authorization import call_if_user from ..onefuzzlib.events import get_events -from ..onefuzzlib.pools import Node, NodeTasks from ..onefuzzlib.request import not_ok, ok, parse_request +from ..onefuzzlib.workers.nodes import Node, NodeTasks def get(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/node_add_ssh_key/__init__.py b/src/api-service/__app__/node_add_ssh_key/__init__.py index fb1aaf055..f06d57354 100644 --- a/src/api-service/__app__/node_add_ssh_key/__init__.py +++ b/src/api-service/__app__/node_add_ssh_key/__init__.py @@ -11,8 +11,8 @@ from onefuzztypes.responses import BoolResult from ..onefuzzlib.endpoint_authorization import call_if_user from ..onefuzzlib.events import get_events -from ..onefuzzlib.pools import Node from ..onefuzzlib.request import not_ok, ok, parse_request +from ..onefuzzlib.workers.nodes import Node def post(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/onefuzzlib/agent_events.py b/src/api-service/__app__/onefuzzlib/agent_events.py index 5bbfe8dd4..4c6e67ffd 100644 --- 
a/src/api-service/__app__/onefuzzlib/agent_events.py +++ b/src/api-service/__app__/onefuzzlib/agent_events.py @@ -25,9 +25,9 @@ from onefuzztypes.models import ( WorkerRunningEvent, ) -from ..onefuzzlib.pools import Node, NodeTasks from ..onefuzzlib.task_event import TaskEvent from ..onefuzzlib.tasks.main import Task +from ..onefuzzlib.workers.nodes import Node, NodeTasks MAX_OUTPUT_SIZE = 4096 diff --git a/src/api-service/__app__/onefuzzlib/autoscale.py b/src/api-service/__app__/onefuzzlib/autoscale.py index a31da3574..253878d61 100644 --- a/src/api-service/__app__/onefuzzlib/autoscale.py +++ b/src/api-service/__app__/onefuzzlib/autoscale.py @@ -10,8 +10,10 @@ from typing import List from onefuzztypes.enums import NodeState, ScalesetState from onefuzztypes.models import AutoScaleConfig, TaskPool -from .pools import Node, Pool, Scaleset from .tasks.main import Task +from .workers.nodes import Node +from .workers.pools import Pool +from .workers.scalesets import Scaleset def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None: diff --git a/src/api-service/__app__/onefuzzlib/endpoint_authorization.py b/src/api-service/__app__/onefuzzlib/endpoint_authorization.py index 32d5d7b35..5ca9534da 100644 --- a/src/api-service/__app__/onefuzzlib/endpoint_authorization.py +++ b/src/api-service/__app__/onefuzzlib/endpoint_authorization.py @@ -13,9 +13,10 @@ from onefuzztypes.enums import ErrorCode from onefuzztypes.models import Error, UserInfo from .azure.creds import get_scaleset_principal_id -from .pools import Pool, Scaleset from .request import not_ok from .user_credentials import parse_jwt_token +from .workers.pools import Pool +from .workers.scalesets import Scaleset @cached(ttl=60) diff --git a/src/api-service/__app__/onefuzzlib/pools.py b/src/api-service/__app__/onefuzzlib/pools.py deleted file mode 100644 index d1d48d322..000000000 --- a/src/api-service/__app__/onefuzzlib/pools.py +++ /dev/null @@ -1,1210 +0,0 @@ -#!/usr/bin/env python -# -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -import datetime -import logging -from typing import Any, Dict, List, Optional, Tuple, Union -from uuid import UUID, uuid4 - -from onefuzztypes.enums import ( - OS, - Architecture, - ErrorCode, - NodeState, - PoolState, - ScalesetState, -) -from onefuzztypes.events import ( - EventNodeCreated, - EventNodeDeleted, - EventNodeStateUpdated, - EventPoolCreated, - EventPoolDeleted, - EventScalesetCreated, - EventScalesetDeleted, - EventScalesetFailed, -) -from onefuzztypes.models import AutoScaleConfig, Error -from onefuzztypes.models import Node as BASE_NODE -from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey -from onefuzztypes.models import NodeTasks as BASE_NODE_TASK -from onefuzztypes.models import Pool as BASE_POOL -from onefuzztypes.models import Result -from onefuzztypes.models import Scaleset as BASE_SCALESET -from onefuzztypes.models import ( - ScalesetNodeState, - ScalesetSummary, - StopNodeCommand, - WorkSet, - WorkSetSummary, - WorkUnitSummary, -) -from onefuzztypes.primitives import PoolName, Region -from pydantic import BaseModel, Field - -from .__version__ import __version__ -from .azure.auth import build_auth -from .azure.image import get_os -from .azure.network import Network -from .azure.queue import ( - clear_queue, - create_queue, - delete_queue, - peek_queue, - queue_object, - remove_first_message, -) -from .azure.storage import StorageType -from .azure.vmss import ( - UnableToUpdate, - create_vmss, - delete_vmss, - delete_vmss_nodes, - get_instance_id, - get_vmss, - get_vmss_size, - list_instance_ids, - reimage_vmss_nodes, - resize_vmss, - update_extensions, -) -from .events import send_event -from .extension import fuzz_extensions -from .orm import MappingIntStrAny, ORMMixin, QueryFilter - -NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) -NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) - -# Future work: -# -# Enabling autoscaling for the scalesets based on the pool work queues. 
-# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics - - -class Node(BASE_NODE, ORMMixin): - # should only be set by Scaleset.reimage_nodes - # should only be unset during agent_registration POST - reimage_queued: bool = Field(default=False) - - @classmethod - def create( - cls, - *, - pool_name: PoolName, - machine_id: UUID, - scaleset_id: Optional[UUID], - version: str, - ) -> "Node": - node = cls( - pool_name=pool_name, - machine_id=machine_id, - scaleset_id=scaleset_id, - version=version, - ) - node.save() - send_event( - EventNodeCreated( - machine_id=node.machine_id, - scaleset_id=node.scaleset_id, - pool_name=node.pool_name, - ) - ) - return node - - @classmethod - def search_states( - cls, - *, - scaleset_id: Optional[UUID] = None, - states: Optional[List[NodeState]] = None, - pool_name: Optional[str] = None, - ) -> List["Node"]: - query: QueryFilter = {} - if scaleset_id: - query["scaleset_id"] = [scaleset_id] - if states: - query["state"] = states - if pool_name: - query["pool_name"] = [pool_name] - return cls.search(query=query) - - @classmethod - def search_outdated( - cls, - *, - scaleset_id: Optional[UUID] = None, - states: Optional[List[NodeState]] = None, - pool_name: Optional[str] = None, - exclude_update_scheduled: bool = False, - num_results: Optional[int] = None, - ) -> List["Node"]: - query: QueryFilter = {} - if scaleset_id: - query["scaleset_id"] = [scaleset_id] - if states: - query["state"] = states - if pool_name: - query["pool_name"] = [pool_name] - - if exclude_update_scheduled: - query["reimage_requested"] = [False] - query["delete_requested"] = [False] - - # azure table query always return false when the column does not exist - # We write the query this way to allow us to get the nodes where the - # version is not defined as well as the nodes with a mismatched version - version_query = "not (version eq '%s')" % __version__ - return cls.search( - query=query, raw_unchecked_filter=version_query, num_results=num_results - ) - - @classmethod - def mark_outdated_nodes(cls) -> None: - # ony update 500 nodes at a time to mitigate timeout issues - outdated = cls.search_outdated(exclude_update_scheduled=True, num_results=500) - for node in outdated: - logging.info( - "node is outdated: %s - node_version:%s api_version:%s", - node.machine_id, - node.version, - __version__, - ) - if node.version == "1.0.0": - node.to_reimage(done=True) - else: - node.to_reimage() - - @classmethod - def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]: - nodes = cls.search(query={"machine_id": [machine_id]}) - if not nodes: - return None - - if len(nodes) != 1: - return None - return nodes[0] - - @classmethod - def key_fields(cls) -> Tuple[str, str]: - return ("pool_name", "machine_id") - - def save_exclude(self) -> Optional[MappingIntStrAny]: - return {"tasks": ...} - - def telemetry_include(self) -> Optional[MappingIntStrAny]: - return { - "machine_id": ..., - "state": ..., - "scaleset_id": ..., - } - - def scaleset_node_exists(self) -> bool: - if self.scaleset_id is None: - return False - - scaleset = Scaleset.get_by_id(self.scaleset_id) - if not isinstance(scaleset, Scaleset): - return False - - instance_id = get_instance_id(scaleset.scaleset_id, self.machine_id) - return isinstance(instance_id, str) - - @classmethod - def stop_task(cls, task_id: UUID) -> None: - # For now, this just re-images the node. 
Eventually, this - # should send a message to the node to let the agent shut down - # gracefully - nodes = NodeTasks.get_nodes_by_task_id(task_id) - for node in nodes: - if node.state not in NodeState.ready_for_reset(): - logging.info( - "stopping machine_id:%s running task:%s", - node.machine_id, - task_id, - ) - node.stop() - - def mark_tasks_stopped_early(self) -> None: - from .tasks.main import Task - - for entry in NodeTasks.get_by_machine_id(self.machine_id): - task = Task.get_by_task_id(entry.task_id) - if isinstance(task, Task): - task.mark_failed( - Error( - code=ErrorCode.TASK_FAILED, - errors=["node reimaged during task execution"], - ) - ) - - def could_shrink_scaleset(self) -> bool: - if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink(): - return True - return False - - def can_process_new_work(self) -> bool: - if self.is_outdated(): - logging.info( - "can_schedule old version machine_id:%s version:%s", - self.machine_id, - self.version, - ) - self.stop() - return False - - if self.state in NodeState.ready_for_reset(): - logging.info( - "can_schedule node is set for reset. machine_id:%s", self.machine_id - ) - return False - - if self.delete_requested: - logging.info( - "can_schedule is set to be deleted. machine_id:%s", - self.machine_id, - ) - self.stop() - return False - - if self.reimage_requested: - logging.info( - "can_schedule is set to be reimaged. machine_id:%s", - self.machine_id, - ) - self.stop() - return False - - if self.could_shrink_scaleset(): - self.set_halt() - logging.info("node scheduled to shrink. machine_id:%s", self.machine_id) - return False - - return True - - def is_outdated(self) -> bool: - return self.version != __version__ - - def send_message(self, message: NodeCommand) -> None: - NodeMessage( - machine_id=self.machine_id, - message=message, - ).save() - - def to_reimage(self, done: bool = False) -> None: - if done: - if self.state not in NodeState.ready_for_reset(): - self.state = NodeState.done - - if not self.reimage_requested and not self.delete_requested: - logging.info("setting reimage_requested: %s", self.machine_id) - self.reimage_requested = True - self.save() - - def add_ssh_public_key(self, public_key: str) -> Result[None]: - if self.scaleset_id is None: - return Error( - code=ErrorCode.INVALID_REQUEST, - errors=["only able to add ssh keys to scaleset nodes"], - ) - - if not public_key.endswith("\n"): - public_key += "\n" - - self.send_message( - NodeCommand(add_ssh_key=NodeCommandAddSshKey(public_key=public_key)) - ) - return None - - def stop(self) -> None: - self.to_reimage() - self.send_message(NodeCommand(stop=StopNodeCommand())) - - def set_shutdown(self) -> None: - # don't give out more work to the node, but let it finish existing work - logging.info("setting delete_requested: %s", self.machine_id) - self.delete_requested = True - self.save() - - def set_halt(self) -> None: - """ Tell the node to stop everything. """ - self.set_shutdown() - self.stop() - self.set_state(NodeState.halt) - - @classmethod - def get_dead_nodes( - cls, scaleset_id: UUID, expiration_period: datetime.timedelta - ) -> List["Node"]: - time_filter = "heartbeat lt datetime'%s'" % ( - (datetime.datetime.utcnow() - expiration_period).isoformat() - ) - return cls.search( - query={"scaleset_id": [scaleset_id]}, - raw_unchecked_filter=time_filter, - ) - - @classmethod - def reimage_long_lived_nodes(cls, scaleset_id: UUID) -> None: - """ - Mark any excessively long lived node to be re-imaged. 
- - This helps keep nodes on scalesets that use `latest` OS image SKUs - reasonably up-to-date with OS patches without disrupting running - fuzzing tasks with patch reboot cycles. - """ - time_filter = "Timestamp lt datetime'%s'" % ( - (datetime.datetime.utcnow() - NODE_REIMAGE_TIME).isoformat() - ) - # skip any nodes already marked for reimage/deletion - for node in cls.search( - query={ - "scaleset_id": [scaleset_id], - "reimage_requested": [False], - "delete_requested": [False], - }, - raw_unchecked_filter=time_filter, - ): - node.to_reimage() - - def set_state(self, state: NodeState) -> None: - if self.state != state: - self.state = state - send_event( - EventNodeStateUpdated( - machine_id=self.machine_id, - pool_name=self.pool_name, - scaleset_id=self.scaleset_id, - state=state, - ) - ) - - self.save() - - def delete(self) -> None: - self.mark_tasks_stopped_early() - NodeTasks.clear_by_machine_id(self.machine_id) - NodeMessage.clear_messages(self.machine_id) - super().delete() - send_event( - EventNodeDeleted( - machine_id=self.machine_id, - pool_name=self.pool_name, - scaleset_id=self.scaleset_id, - ) - ) - - -class NodeTasks(BASE_NODE_TASK, ORMMixin): - @classmethod - def key_fields(cls) -> Tuple[str, str]: - return ("machine_id", "task_id") - - def telemetry_include(self) -> Optional[MappingIntStrAny]: - return { - "machine_id": ..., - "task_id": ..., - "state": ..., - } - - @classmethod - def get_nodes_by_task_id(cls, task_id: UUID) -> List["Node"]: - result = [] - for entry in cls.search(query={"task_id": [task_id]}): - node = Node.get_by_machine_id(entry.machine_id) - if node: - result.append(node) - return result - - @classmethod - def get_node_assignments(cls, task_id: UUID) -> List[NodeAssignment]: - result = [] - for entry in cls.search(query={"task_id": [task_id]}): - node = Node.get_by_machine_id(entry.machine_id) - if node: - node_assignment = NodeAssignment( - node_id=node.machine_id, - scaleset_id=node.scaleset_id, - state=entry.state, - ) - result.append(node_assignment) - - return result - - @classmethod - def get_by_machine_id(cls, machine_id: UUID) -> List["NodeTasks"]: - return cls.search(query={"machine_id": [machine_id]}) - - @classmethod - def get_by_task_id(cls, task_id: UUID) -> List["NodeTasks"]: - return cls.search(query={"task_id": [task_id]}) - - @classmethod - def clear_by_machine_id(cls, machine_id: UUID) -> None: - logging.info("clearing tasks for node: %s", machine_id) - for entry in cls.get_by_machine_id(machine_id): - entry.delete() - - -# this isn't anticipated to be needed by the client, hence it not -# being in onefuzztypes -class NodeMessage(ORMMixin): - machine_id: UUID - message_id: str = Field(default_factory=datetime.datetime.utcnow().timestamp) - message: NodeCommand - - @classmethod - def key_fields(cls) -> Tuple[str, str]: - return ("machine_id", "message_id") - - @classmethod - def get_messages( - cls, machine_id: UUID, num_results: int = None - ) -> List["NodeMessage"]: - entries: List["NodeMessage"] = cls.search( - query={"machine_id": [machine_id]}, num_results=num_results - ) - return entries - - @classmethod - def clear_messages(cls, machine_id: UUID) -> None: - logging.info("clearing messages for node: %s", machine_id) - messages = cls.get_messages(machine_id) - for message in messages: - message.delete() - - -class Pool(BASE_POOL, ORMMixin): - @classmethod - def create( - cls, - *, - name: PoolName, - os: OS, - arch: Architecture, - managed: bool, - client_id: Optional[UUID], - autoscale: Optional[AutoScaleConfig], - ) -> "Pool": - 
entry = cls( - name=name, - os=os, - arch=arch, - managed=managed, - client_id=client_id, - config=None, - autoscale=autoscale, - ) - entry.save() - send_event( - EventPoolCreated( - pool_name=name, - os=os, - arch=arch, - managed=managed, - autoscale=autoscale, - ) - ) - return entry - - def save_exclude(self) -> Optional[MappingIntStrAny]: - return { - "nodes": ..., - "queue": ..., - "work_queue": ..., - "config": ..., - "node_summary": ..., - } - - def export_exclude(self) -> Optional[MappingIntStrAny]: - return { - "etag": ..., - "timestamp": ..., - } - - def telemetry_include(self) -> Optional[MappingIntStrAny]: - return { - "pool_id": ..., - "os": ..., - "state": ..., - "managed": ..., - } - - def populate_scaleset_summary(self) -> None: - self.scaleset_summary = [ - ScalesetSummary(scaleset_id=x.scaleset_id, state=x.state) - for x in Scaleset.search_by_pool(self.name) - ] - - def populate_work_queue(self) -> None: - self.work_queue = [] - - # Only populate the work queue summaries if the pool is initialized. We - # can then be sure that the queue is available in the operations below. - if self.state == PoolState.init: - return - - worksets = peek_queue( - self.get_pool_queue(), StorageType.corpus, object_type=WorkSet - ) - - for workset in worksets: - work_units = [ - WorkUnitSummary( - job_id=work_unit.job_id, - task_id=work_unit.task_id, - task_type=work_unit.task_type, - ) - for work_unit in workset.work_units - ] - self.work_queue.append(WorkSetSummary(work_units=work_units)) - - def get_pool_queue(self) -> str: - return "pool-%s" % self.pool_id.hex - - def init(self) -> None: - create_queue(self.get_pool_queue(), StorageType.corpus) - self.state = PoolState.running - self.save() - - def schedule_workset(self, work_set: WorkSet) -> bool: - # Don't schedule work for pools that can't and won't do work. 
- if self.state in [PoolState.shutdown, PoolState.halt]: - return False - - return queue_object( - self.get_pool_queue(), - work_set, - StorageType.corpus, - ) - - @classmethod - def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]: - pools = cls.search(query={"pool_id": [pool_id]}) - if not pools: - return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"]) - - if len(pools) != 1: - return Error( - code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"] - ) - pool = pools[0] - return pool - - @classmethod - def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]: - pools = cls.search(query={"name": [name]}) - if not pools: - return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"]) - - if len(pools) != 1: - return Error( - code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"] - ) - pool = pools[0] - return pool - - @classmethod - def search_states(cls, *, states: Optional[List[PoolState]] = None) -> List["Pool"]: - query: QueryFilter = {} - if states: - query["state"] = states - return cls.search(query=query) - - def set_shutdown(self, now: bool) -> None: - if self.state in [PoolState.halt, PoolState.shutdown]: - return - - if now: - self.state = PoolState.halt - else: - self.state = PoolState.shutdown - - self.save() - - def shutdown(self) -> None: - """ shutdown allows nodes to finish current work then delete """ - scalesets = Scaleset.search_by_pool(self.name) - nodes = Node.search(query={"pool_name": [self.name]}) - if not scalesets and not nodes: - logging.info("pool stopped, deleting: %s", self.name) - - self.state = PoolState.halt - self.delete() - return - - for scaleset in scalesets: - scaleset.set_shutdown(now=False) - - for node in nodes: - node.set_shutdown() - - self.save() - - def halt(self) -> None: - """ halt the pool immediately """ - - scalesets = Scaleset.search_by_pool(self.name) - nodes = Node.search(query={"pool_name": [self.name]}) - if not scalesets and not nodes: - delete_queue(self.get_pool_queue(), StorageType.corpus) - logging.info("pool stopped, deleting: %s", self.name) - self.state = PoolState.halt - self.delete() - return - - for scaleset in scalesets: - scaleset.state = ScalesetState.halt - scaleset.save() - - for node in nodes: - node.set_halt() - - self.save() - - @classmethod - def key_fields(cls) -> Tuple[str, str]: - return ("name", "pool_id") - - def delete(self) -> None: - super().delete() - send_event(EventPoolDeleted(pool_name=self.name)) - - -class Scaleset(BASE_SCALESET, ORMMixin): - def save_exclude(self) -> Optional[MappingIntStrAny]: - return {"nodes": ...} - - def telemetry_include(self) -> Optional[MappingIntStrAny]: - return { - "scaleset_id": ..., - "os": ..., - "vm_sku": ..., - "size": ..., - "spot_instances": ..., - } - - @classmethod - def create( - cls, - *, - pool_name: PoolName, - vm_sku: str, - image: str, - region: Region, - size: int, - spot_instances: bool, - tags: Dict[str, str], - client_id: Optional[UUID] = None, - client_object_id: Optional[UUID] = None, - ) -> "Scaleset": - entry = cls( - pool_name=pool_name, - vm_sku=vm_sku, - image=image, - region=region, - size=size, - spot_instances=spot_instances, - auth=build_auth(), - client_id=client_id, - client_object_id=client_object_id, - tags=tags, - ) - entry.save() - send_event( - EventScalesetCreated( - scaleset_id=entry.scaleset_id, - pool_name=entry.pool_name, - vm_sku=vm_sku, - image=image, - region=region, - size=size, - ) - ) - return entry - - @classmethod - def search_by_pool(cls, pool_name: 
PoolName) -> List["Scaleset"]: - return cls.search(query={"pool_name": [pool_name]}) - - @classmethod - def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]: - scalesets = cls.search(query={"scaleset_id": [scaleset_id]}) - if not scalesets: - return Error( - code=ErrorCode.INVALID_REQUEST, errors=["unable to find scaleset"] - ) - - if len(scalesets) != 1: - return Error( - code=ErrorCode.INVALID_REQUEST, errors=["error identifying scaleset"] - ) - scaleset = scalesets[0] - return scaleset - - @classmethod - def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]: - return cls.search(query={"client_object_id": [object_id]}) - - def set_failed(self, error: Error) -> None: - if self.error is not None: - return - - self.error = error - self.state = ScalesetState.creation_failed - self.save() - - send_event( - EventScalesetFailed( - scaleset_id=self.scaleset_id, pool_name=self.pool_name, error=self.error - ) - ) - - def init(self) -> None: - logging.info("scaleset init: %s", self.scaleset_id) - - ScalesetShrinkQueue(self.scaleset_id).create() - - # Handle the race condition between a pool being deleted and a - # scaleset being added to the pool. - pool = Pool.get_by_name(self.pool_name) - if isinstance(pool, Error): - self.set_failed(pool) - return - - if pool.state == PoolState.init: - logging.info( - "scaleset waiting for pool: %s - %s", self.pool_name, self.scaleset_id - ) - elif pool.state == PoolState.running: - image_os = get_os(self.region, self.image) - if isinstance(image_os, Error): - self.set_failed(image_os) - return - - elif image_os != pool.os: - error = Error( - code=ErrorCode.INVALID_REQUEST, - errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)], - ) - self.set_failed(error) - return - else: - self.state = ScalesetState.setup - else: - self.state = ScalesetState.setup - - self.save() - - def setup(self) -> None: - # TODO: How do we pass in SSH configs for Windows? Previously - # This was done as part of the generated per-task setup script. 
- logging.info("scaleset setup: %s", self.scaleset_id) - - network = Network(self.region) - network_id = network.get_id() - if not network_id: - logging.info("creating network: %s", self.region) - result = network.create() - if isinstance(result, Error): - self.set_failed(result) - return - self.save() - return - - if self.auth is None: - error = Error( - code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"] - ) - self.set_failed(error) - return - - vmss = get_vmss(self.scaleset_id) - if vmss is None: - pool = Pool.get_by_name(self.pool_name) - if isinstance(pool, Error): - self.set_failed(pool) - return - - logging.info("creating scaleset: %s", self.scaleset_id) - extensions = fuzz_extensions(pool, self) - result = create_vmss( - self.region, - self.scaleset_id, - self.vm_sku, - self.size, - self.image, - network_id, - self.spot_instances, - extensions, - self.auth.password, - self.auth.public_key, - self.tags, - ) - if isinstance(result, Error): - self.set_failed(result) - return - else: - logging.info("creating scaleset: %s", self.scaleset_id) - elif vmss.provisioning_state == "Creating": - logging.info("Waiting on scaleset creation: %s", self.scaleset_id) - self.try_set_identity(vmss) - else: - logging.info("scaleset running: %s", self.scaleset_id) - identity_result = self.try_set_identity(vmss) - if identity_result: - self.set_failed(identity_result) - return - else: - self.state = ScalesetState.running - self.save() - - def try_set_identity(self, vmss: Any) -> Optional[Error]: - def get_error() -> Error: - return Error( - code=ErrorCode.VM_CREATE_FAILED, - errors=[ - "The scaleset is expected to have exactly 1 user assigned identity" - ], - ) - - if self.client_object_id: - return None - if ( - vmss.identity - and vmss.identity.user_assigned_identities - and (len(vmss.identity.user_assigned_identities) != 1) - ): - return get_error() - - user_assinged_identities = list(vmss.identity.user_assigned_identities.values()) - - if user_assinged_identities[0].principal_id: - self.client_object_id = user_assinged_identities[0].principal_id - return None - else: - return get_error() - - # result = 'did I modify the scaleset in azure' - def cleanup_nodes(self) -> bool: - if self.state == ScalesetState.halt: - logging.info("halting scaleset: %s", self.scaleset_id) - self.halt() - return True - - Node.reimage_long_lived_nodes(self.scaleset_id) - - to_reimage = [] - to_delete = [] - - # ground truth of existing nodes - azure_nodes = list_instance_ids(self.scaleset_id) - - nodes = Node.search_states(scaleset_id=self.scaleset_id) - - # Nodes do not exists in scalesets but in table due to unknown failure - for node in nodes: - if node.machine_id not in azure_nodes: - logging.info( - "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id - ) - node.delete() - - existing_nodes = [x for x in nodes if x.machine_id in azure_nodes] - nodes_to_reset = [ - x for x in existing_nodes if x.state in NodeState.ready_for_reset() - ] - - for node in nodes_to_reset: - if node.delete_requested: - to_delete.append(node) - else: - if ScalesetShrinkQueue(self.scaleset_id).should_shrink(): - node.set_halt() - to_delete.append(node) - elif not node.reimage_queued: - # only add nodes that are not already set to reschedule - to_reimage.append(node) - - dead_nodes = Node.get_dead_nodes(self.scaleset_id, NODE_EXPIRATION_TIME) - for node in dead_nodes: - node.set_halt() - to_reimage.append(node) - - # Perform operations until they fail due to scaleset getting locked - try: - if to_delete: - logging.info( 
- "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete) - ) - self.delete_nodes(to_delete) - for node in to_delete: - node.set_halt() - - if to_reimage: - self.reimage_nodes(to_reimage) - except UnableToUpdate: - logging.info("scaleset update already in progress: %s", self.scaleset_id) - - return bool(to_reimage) or bool(to_delete) - - def _resize_equal(self) -> None: - # NOTE: this is the only place we reset to the 'running' state. - # This ensures that our idea of scaleset size agrees with Azure - node_count = len(Node.search_states(scaleset_id=self.scaleset_id)) - if node_count == self.size: - logging.info("resize finished: %s", self.scaleset_id) - self.state = ScalesetState.running - self.save() - return - else: - logging.info( - "resize is finished, waiting for nodes to check in: " - "%s (%d of %d nodes checked in)", - self.scaleset_id, - node_count, - self.size, - ) - return - - def _resize_grow(self) -> None: - try: - resize_vmss(self.scaleset_id, self.size) - except UnableToUpdate: - logging.info("scaleset is mid-operation already") - return - - def _resize_shrink(self, to_remove: int) -> None: - queue = ScalesetShrinkQueue(self.scaleset_id) - for _ in range(to_remove): - queue.add_entry() - - def resize(self) -> None: - # no longer needing to resize - if self.state != ScalesetState.resize: - return - - logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size) - - # reset the node delete queue - ScalesetShrinkQueue(self.scaleset_id).clear() - - # just in case, always ensure size is within max capacity - self.size = min(self.size, self.max_size()) - - # Treat Azure knowledge of the size of the scaleset as "ground truth" - size = get_vmss_size(self.scaleset_id) - if size is None: - logging.info("scaleset is unavailable: %s", self.scaleset_id) - return - - if size == self.size: - self._resize_equal() - elif self.size > size: - self._resize_grow() - else: - self._resize_shrink(size - self.size) - - def delete_nodes(self, nodes: List[Node]) -> None: - if not nodes: - logging.debug("no nodes to delete") - return - - if self.state == ScalesetState.halt: - logging.debug("scaleset delete will delete node: %s", self.scaleset_id) - return - - machine_ids = [] - for node in nodes: - if node.debug_keep_node: - logging.warning( - "delete manually overridden %s:%s", - self.scaleset_id, - node.machine_id, - ) - else: - machine_ids.append(node.machine_id) - - logging.info("deleting %s:%s", self.scaleset_id, machine_ids) - delete_vmss_nodes(self.scaleset_id, machine_ids) - - def reimage_nodes(self, nodes: List[Node]) -> None: - if not nodes: - logging.debug("no nodes to reimage") - return - - if self.state == ScalesetState.shutdown: - self.delete_nodes(nodes) - return - - if self.state == ScalesetState.halt: - logging.debug("scaleset delete will delete node: %s", self.scaleset_id) - return - - machine_ids = [] - for node in nodes: - if node.debug_keep_node: - logging.warning( - "reimage manually overridden %s:%s", - self.scaleset_id, - node.machine_id, - ) - else: - machine_ids.append(node.machine_id) - - result = reimage_vmss_nodes(self.scaleset_id, machine_ids) - if isinstance(result, Error): - raise Exception( - "unable to reimage nodes: %s:%s - %s" - % (self.scaleset_id, machine_ids, result) - ) - for node in nodes: - node.reimage_queued = True - node.save() - - def set_shutdown(self, now: bool) -> None: - if self.state in [ScalesetState.halt, ScalesetState.shutdown]: - return - - if now: - self.state = ScalesetState.halt - else: - self.state = ScalesetState.shutdown - - 
self.save() - - def shutdown(self) -> None: - size = get_vmss_size(self.scaleset_id) - logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size) - nodes = Node.search_states(scaleset_id=self.scaleset_id) - for node in nodes: - node.set_shutdown() - if size is None or size == 0: - self.halt() - - def halt(self) -> None: - ScalesetShrinkQueue(self.scaleset_id).delete() - - for node in Node.search_states(scaleset_id=self.scaleset_id): - logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id) - node.delete() - - logging.info("scaleset delete starting: %s", self.scaleset_id) - if delete_vmss(self.scaleset_id): - logging.info("scaleset deleted: %s", self.scaleset_id) - self.delete() - else: - self.save() - - @classmethod - def scaleset_max_size(cls, image: str) -> int: - # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/ - # virtual-machine-scale-sets-placement-groups#checklist-for-using-large-scale-sets - if image.startswith("/"): - return 600 - else: - return 1000 - - def max_size(self) -> int: - return Scaleset.scaleset_max_size(self.image) - - @classmethod - def search_states( - cls, *, states: Optional[List[ScalesetState]] = None - ) -> List["Scaleset"]: - query: QueryFilter = {} - if states: - query["state"] = states - return cls.search(query=query) - - def update_nodes(self) -> None: - # Be in at-least 'setup' before checking for the list of VMs - if self.state == ScalesetState.init: - return - - nodes = Node.search_states(scaleset_id=self.scaleset_id) - azure_nodes = list_instance_ids(self.scaleset_id) - - self.nodes = [] - - for (machine_id, instance_id) in azure_nodes.items(): - node_state: Optional[ScalesetNodeState] = None - for node in nodes: - if node.machine_id == machine_id: - node_state = ScalesetNodeState( - machine_id=machine_id, - instance_id=instance_id, - state=node.state, - ) - break - if not node_state: - node_state = ScalesetNodeState( - machine_id=machine_id, - instance_id=instance_id, - ) - self.nodes.append(node_state) - - def update_configs(self) -> None: - if not self.needs_config_update: - logging.debug("config update not needed: %s", self.scaleset_id) - - logging.info("updating scaleset configs: %s", self.scaleset_id) - - pool = Pool.get_by_name(self.pool_name) - if isinstance(pool, Error): - logging.error( - "unable to find pool during config update: %s - %s", - self.scaleset_id, - pool, - ) - self.set_failed(pool) - return - - extensions = fuzz_extensions(pool, self) - try: - update_extensions(self.scaleset_id, extensions) - self.needs_config_update = False - self.save() - except UnableToUpdate: - logging.debug( - "unable to update configs, update already in progress: %s", - self.scaleset_id, - ) - - @classmethod - def key_fields(cls) -> Tuple[str, str]: - return ("pool_name", "scaleset_id") - - def delete(self) -> None: - super().delete() - send_event( - EventScalesetDeleted(scaleset_id=self.scaleset_id, pool_name=self.pool_name) - ) - - -class ShrinkEntry(BaseModel): - shrink_id: UUID = Field(default_factory=uuid4) - - -class ScalesetShrinkQueue: - def __init__(self, scaleset_id: UUID): - self.scaleset_id = scaleset_id - - def queue_name(self) -> str: - return "to-shrink-%s" % self.scaleset_id.hex - - def clear(self) -> None: - clear_queue(self.queue_name(), StorageType.config) - - def create(self) -> None: - create_queue(self.queue_name(), StorageType.config) - - def delete(self) -> None: - delete_queue(self.queue_name(), StorageType.config) - - def add_entry(self) -> None: - 
queue_object(self.queue_name(), ShrinkEntry(), StorageType.config) - - def should_shrink(self) -> bool: - return remove_first_message(self.queue_name(), StorageType.config) diff --git a/src/api-service/__app__/onefuzzlib/tasks/main.py b/src/api-service/__app__/onefuzzlib/tasks/main.py index 03fff1501..230c9c2c5 100644 --- a/src/api-service/__app__/onefuzzlib/tasks/main.py +++ b/src/api-service/__app__/onefuzzlib/tasks/main.py @@ -24,8 +24,10 @@ from ..azure.queue import create_queue, delete_queue from ..azure.storage import StorageType from ..events import send_event from ..orm import MappingIntStrAny, ORMMixin, QueryFilter -from ..pools import Node, Pool, Scaleset from ..proxy_forward import ProxyForward +from ..workers.nodes import Node +from ..workers.pools import Pool +from ..workers.scalesets import Scaleset class Task(BASE_TASK, ORMMixin): diff --git a/src/api-service/__app__/onefuzzlib/tasks/scheduler.py b/src/api-service/__app__/onefuzzlib/tasks/scheduler.py index 94ed9b285..e70b310f3 100644 --- a/src/api-service/__app__/onefuzzlib/tasks/scheduler.py +++ b/src/api-service/__app__/onefuzzlib/tasks/scheduler.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from ..azure.containers import blob_exists, get_container_sas_url from ..azure.storage import StorageType -from ..pools import Pool +from ..workers.pools import Pool from .config import build_task_config, get_setup_container from .main import Task diff --git a/src/api-service/__app__/onefuzzlib/updates.py b/src/api-service/__app__/onefuzzlib/updates.py index 91b74dcdf..16ce3666b 100644 --- a/src/api-service/__app__/onefuzzlib/updates.py +++ b/src/api-service/__app__/onefuzzlib/updates.py @@ -60,10 +60,12 @@ def queue_update( def execute_update(update: Update) -> None: from .jobs import Job from .orm import ORMMixin - from .pools import Node, Pool, Scaleset from .proxy import Proxy from .repro import Repro from .tasks.main import Task + from .workers.nodes import Node + from .workers.pools import Pool + from .workers.scalesets import Scaleset update_objects: Dict[UpdateType, Type[ORMMixin]] = { UpdateType.Task: Task, diff --git a/src/api-service/__app__/onefuzzlib/workers/__init__.py b/src/api-service/__app__/onefuzzlib/workers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/api-service/__app__/onefuzzlib/workers/nodes.py b/src/api-service/__app__/onefuzzlib/workers/nodes.py new file mode 100644 index 000000000..f9d2844b3 --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/workers/nodes.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import datetime +import logging +from typing import List, Optional, Tuple +from uuid import UUID + +from onefuzztypes.enums import ErrorCode, NodeState +from onefuzztypes.events import ( + EventNodeCreated, + EventNodeDeleted, + EventNodeStateUpdated, +) +from onefuzztypes.models import Error +from onefuzztypes.models import Node as BASE_NODE +from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey +from onefuzztypes.models import NodeTasks as BASE_NODE_TASK +from onefuzztypes.models import Result, StopNodeCommand +from onefuzztypes.primitives import PoolName +from pydantic import Field + +from ..__version__ import __version__ +from ..azure.vmss import get_instance_id +from ..events import send_event +from ..orm import MappingIntStrAny, ORMMixin, QueryFilter + +NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) +NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) + +# Future work: +# +# Enabling autoscaling for the scalesets based on the pool work queues. +# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics + + +class Node(BASE_NODE, ORMMixin): + # should only be set by Scaleset.reimage_nodes + # should only be unset during agent_registration POST + reimage_queued: bool = Field(default=False) + + @classmethod + def create( + cls, + *, + pool_name: PoolName, + machine_id: UUID, + scaleset_id: Optional[UUID], + version: str, + ) -> "Node": + node = cls( + pool_name=pool_name, + machine_id=machine_id, + scaleset_id=scaleset_id, + version=version, + ) + node.save() + send_event( + EventNodeCreated( + machine_id=node.machine_id, + scaleset_id=node.scaleset_id, + pool_name=node.pool_name, + ) + ) + return node + + @classmethod + def search_states( + cls, + *, + scaleset_id: Optional[UUID] = None, + states: Optional[List[NodeState]] = None, + pool_name: Optional[str] = None, + ) -> List["Node"]: + query: QueryFilter = {} + if scaleset_id: + query["scaleset_id"] = [scaleset_id] + if states: + query["state"] = states + if pool_name: + query["pool_name"] = [pool_name] + return cls.search(query=query) + + @classmethod + def search_outdated( + cls, + *, + scaleset_id: Optional[UUID] = None, + states: Optional[List[NodeState]] = None, + pool_name: Optional[str] = None, + exclude_update_scheduled: bool = False, + num_results: Optional[int] = None, + ) -> List["Node"]: + query: QueryFilter = {} + if scaleset_id: + query["scaleset_id"] = [scaleset_id] + if states: + query["state"] = states + if pool_name: + query["pool_name"] = [pool_name] + + if exclude_update_scheduled: + query["reimage_requested"] = [False] + query["delete_requested"] = [False] + + # azure table queries always return false when the column does not exist. + # We write the query this way to allow us to get the nodes where the + # version is not defined as well as the nodes with a mismatched version + version_query = "not (version eq '%s')" % __version__ + return cls.search( + query=query, raw_unchecked_filter=version_query, num_results=num_results + ) + + @classmethod + def mark_outdated_nodes(cls) -> None: + # only update 500 nodes at a time to mitigate timeout issues + outdated = cls.search_outdated(exclude_update_scheduled=True, num_results=500) + for node in outdated: + logging.info( + "node is outdated: %s - node_version:%s api_version:%s", + node.machine_id, + node.version, + __version__, + ) + if node.version == "1.0.0": + node.to_reimage(done=True) + else: + node.to_reimage() + + @classmethod + def
get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]: + nodes = cls.search(query={"machine_id": [machine_id]}) + if not nodes: + return None + + if len(nodes) != 1: + return None + return nodes[0] + + @classmethod + def key_fields(cls) -> Tuple[str, str]: + return ("pool_name", "machine_id") + + def save_exclude(self) -> Optional[MappingIntStrAny]: + return {"tasks": ...} + + def telemetry_include(self) -> Optional[MappingIntStrAny]: + return { + "machine_id": ..., + "state": ..., + "scaleset_id": ..., + } + + def scaleset_node_exists(self) -> bool: + if self.scaleset_id is None: + return False + + from .scalesets import Scaleset + + scaleset = Scaleset.get_by_id(self.scaleset_id) + if not isinstance(scaleset, Scaleset): + return False + + instance_id = get_instance_id(scaleset.scaleset_id, self.machine_id) + return isinstance(instance_id, str) + + @classmethod + def stop_task(cls, task_id: UUID) -> None: + # For now, this just re-images the node. Eventually, this + # should send a message to the node to let the agent shut down + # gracefully + nodes = NodeTasks.get_nodes_by_task_id(task_id) + for node in nodes: + if node.state not in NodeState.ready_for_reset(): + logging.info( + "stopping machine_id:%s running task:%s", + node.machine_id, + task_id, + ) + node.stop() + + def mark_tasks_stopped_early(self) -> None: + from ..tasks.main import Task + + for entry in NodeTasks.get_by_machine_id(self.machine_id): + task = Task.get_by_task_id(entry.task_id) + if isinstance(task, Task): + task.mark_failed( + Error( + code=ErrorCode.TASK_FAILED, + errors=["node reimaged during task execution"], + ) + ) + + def could_shrink_scaleset(self) -> bool: + from .scalesets import ScalesetShrinkQueue + + if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink(): + return True + return False + + def can_process_new_work(self) -> bool: + if self.is_outdated(): + logging.info( + "can_schedule old version machine_id:%s version:%s", + self.machine_id, + self.version, + ) + self.stop() + return False + + if self.state in NodeState.ready_for_reset(): + logging.info( + "can_schedule node is set for reset. machine_id:%s", self.machine_id + ) + return False + + if self.delete_requested: + logging.info( + "can_schedule is set to be deleted. machine_id:%s", + self.machine_id, + ) + self.stop() + return False + + if self.reimage_requested: + logging.info( + "can_schedule is set to be reimaged. machine_id:%s", + self.machine_id, + ) + self.stop() + return False + + if self.could_shrink_scaleset(): + self.set_halt() + logging.info("node scheduled to shrink. 
machine_id:%s", self.machine_id) + return False + + return True + + def is_outdated(self) -> bool: + return self.version != __version__ + + def send_message(self, message: NodeCommand) -> None: + NodeMessage( + machine_id=self.machine_id, + message=message, + ).save() + + def to_reimage(self, done: bool = False) -> None: + if done: + if self.state not in NodeState.ready_for_reset(): + self.state = NodeState.done + + if not self.reimage_requested and not self.delete_requested: + logging.info("setting reimage_requested: %s", self.machine_id) + self.reimage_requested = True + self.save() + + def add_ssh_public_key(self, public_key: str) -> Result[None]: + if self.scaleset_id is None: + return Error( + code=ErrorCode.INVALID_REQUEST, + errors=["only able to add ssh keys to scaleset nodes"], + ) + + if not public_key.endswith("\n"): + public_key += "\n" + + self.send_message( + NodeCommand(add_ssh_key=NodeCommandAddSshKey(public_key=public_key)) + ) + return None + + def stop(self) -> None: + self.to_reimage() + self.send_message(NodeCommand(stop=StopNodeCommand())) + + def set_shutdown(self) -> None: + # don't give out more work to the node, but let it finish existing work + logging.info("setting delete_requested: %s", self.machine_id) + self.delete_requested = True + self.save() + + def set_halt(self) -> None: + """ Tell the node to stop everything. """ + self.set_shutdown() + self.stop() + self.set_state(NodeState.halt) + + @classmethod + def get_dead_nodes( + cls, scaleset_id: UUID, expiration_period: datetime.timedelta + ) -> List["Node"]: + time_filter = "heartbeat lt datetime'%s'" % ( + (datetime.datetime.utcnow() - expiration_period).isoformat() + ) + return cls.search( + query={"scaleset_id": [scaleset_id]}, + raw_unchecked_filter=time_filter, + ) + + @classmethod + def reimage_long_lived_nodes(cls, scaleset_id: UUID) -> None: + """ + Mark any excessively long lived node to be re-imaged. + + This helps keep nodes on scalesets that use `latest` OS image SKUs + reasonably up-to-date with OS patches without disrupting running + fuzzing tasks with patch reboot cycles. 
+ """ + time_filter = "Timestamp lt datetime'%s'" % ( + (datetime.datetime.utcnow() - NODE_REIMAGE_TIME).isoformat() + ) + # skip any nodes already marked for reimage/deletion + for node in cls.search( + query={ + "scaleset_id": [scaleset_id], + "reimage_requested": [False], + "delete_requested": [False], + }, + raw_unchecked_filter=time_filter, + ): + node.to_reimage() + + def set_state(self, state: NodeState) -> None: + if self.state != state: + self.state = state + send_event( + EventNodeStateUpdated( + machine_id=self.machine_id, + pool_name=self.pool_name, + scaleset_id=self.scaleset_id, + state=state, + ) + ) + + self.save() + + def delete(self) -> None: + self.mark_tasks_stopped_early() + NodeTasks.clear_by_machine_id(self.machine_id) + NodeMessage.clear_messages(self.machine_id) + super().delete() + send_event( + EventNodeDeleted( + machine_id=self.machine_id, + pool_name=self.pool_name, + scaleset_id=self.scaleset_id, + ) + ) + + +class NodeTasks(BASE_NODE_TASK, ORMMixin): + @classmethod + def key_fields(cls) -> Tuple[str, str]: + return ("machine_id", "task_id") + + def telemetry_include(self) -> Optional[MappingIntStrAny]: + return { + "machine_id": ..., + "task_id": ..., + "state": ..., + } + + @classmethod + def get_nodes_by_task_id(cls, task_id: UUID) -> List["Node"]: + result = [] + for entry in cls.search(query={"task_id": [task_id]}): + node = Node.get_by_machine_id(entry.machine_id) + if node: + result.append(node) + return result + + @classmethod + def get_node_assignments(cls, task_id: UUID) -> List[NodeAssignment]: + result = [] + for entry in cls.search(query={"task_id": [task_id]}): + node = Node.get_by_machine_id(entry.machine_id) + if node: + node_assignment = NodeAssignment( + node_id=node.machine_id, + scaleset_id=node.scaleset_id, + state=entry.state, + ) + result.append(node_assignment) + + return result + + @classmethod + def get_by_machine_id(cls, machine_id: UUID) -> List["NodeTasks"]: + return cls.search(query={"machine_id": [machine_id]}) + + @classmethod + def get_by_task_id(cls, task_id: UUID) -> List["NodeTasks"]: + return cls.search(query={"task_id": [task_id]}) + + @classmethod + def clear_by_machine_id(cls, machine_id: UUID) -> None: + logging.info("clearing tasks for node: %s", machine_id) + for entry in cls.get_by_machine_id(machine_id): + entry.delete() + + +# this isn't anticipated to be needed by the client, hence it not +# being in onefuzztypes +class NodeMessage(ORMMixin): + machine_id: UUID + message_id: str = Field(default_factory=lambda: str(datetime.datetime.utcnow().timestamp())) + message: NodeCommand + + @classmethod + def key_fields(cls) -> Tuple[str, str]: + return ("machine_id", "message_id") + + @classmethod + def get_messages( + cls, machine_id: UUID, num_results: Optional[int] = None + ) -> List["NodeMessage"]: + entries: List["NodeMessage"] = cls.search( + query={"machine_id": [machine_id]}, num_results=num_results + ) + return entries + + @classmethod + def clear_messages(cls, machine_id: UUID) -> None: + logging.info("clearing messages for node: %s", machine_id) + messages = cls.get_messages(machine_id) + for message in messages: + message.delete() diff --git a/src/api-service/__app__/onefuzzlib/workers/pools.py b/src/api-service/__app__/onefuzzlib/workers/pools.py new file mode 100644 index 000000000..f8ae1d2bc --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/workers/pools.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License.
+ +import datetime +import logging +from typing import List, Optional, Tuple, Union +from uuid import UUID + +from onefuzztypes.enums import OS, Architecture, ErrorCode, PoolState, ScalesetState +from onefuzztypes.events import EventPoolCreated, EventPoolDeleted +from onefuzztypes.models import AutoScaleConfig, Error +from onefuzztypes.models import Pool as BASE_POOL +from onefuzztypes.models import ( + ScalesetSummary, + WorkSet, + WorkSetSummary, + WorkUnitSummary, +) +from onefuzztypes.primitives import PoolName + +from ..azure.queue import create_queue, delete_queue, peek_queue, queue_object +from ..azure.storage import StorageType +from ..events import send_event +from ..orm import MappingIntStrAny, ORMMixin, QueryFilter + +NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) +NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) + +# Future work: +# +# Enabling autoscaling for the scalesets based on the pool work queues. +# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics + + +class Pool(BASE_POOL, ORMMixin): + @classmethod + def create( + cls, + *, + name: PoolName, + os: OS, + arch: Architecture, + managed: bool, + client_id: Optional[UUID], + autoscale: Optional[AutoScaleConfig], + ) -> "Pool": + entry = cls( + name=name, + os=os, + arch=arch, + managed=managed, + client_id=client_id, + config=None, + autoscale=autoscale, + ) + entry.save() + send_event( + EventPoolCreated( + pool_name=name, + os=os, + arch=arch, + managed=managed, + autoscale=autoscale, + ) + ) + return entry + + def save_exclude(self) -> Optional[MappingIntStrAny]: + return { + "nodes": ..., + "queue": ..., + "work_queue": ..., + "config": ..., + "node_summary": ..., + } + + def export_exclude(self) -> Optional[MappingIntStrAny]: + return { + "etag": ..., + "timestamp": ..., + } + + def telemetry_include(self) -> Optional[MappingIntStrAny]: + return { + "pool_id": ..., + "os": ..., + "state": ..., + "managed": ..., + } + + def populate_scaleset_summary(self) -> None: + from .scalesets import Scaleset + + self.scaleset_summary = [ + ScalesetSummary(scaleset_id=x.scaleset_id, state=x.state) + for x in Scaleset.search_by_pool(self.name) + ] + + def populate_work_queue(self) -> None: + self.work_queue = [] + + # Only populate the work queue summaries if the pool is initialized. We + # can then be sure that the queue is available in the operations below. + if self.state == PoolState.init: + return + + worksets = peek_queue( + self.get_pool_queue(), StorageType.corpus, object_type=WorkSet + ) + + for workset in worksets: + work_units = [ + WorkUnitSummary( + job_id=work_unit.job_id, + task_id=work_unit.task_id, + task_type=work_unit.task_type, + ) + for work_unit in workset.work_units + ] + self.work_queue.append(WorkSetSummary(work_units=work_units)) + + def get_pool_queue(self) -> str: + return "pool-%s" % self.pool_id.hex + + def init(self) -> None: + create_queue(self.get_pool_queue(), StorageType.corpus) + self.state = PoolState.running + self.save() + + def schedule_workset(self, work_set: WorkSet) -> bool: + # Don't schedule work for pools that can't and won't do work. 
+ if self.state in [PoolState.shutdown, PoolState.halt]: + return False + + return queue_object( + self.get_pool_queue(), + work_set, + StorageType.corpus, + ) + + @classmethod + def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]: + pools = cls.search(query={"pool_id": [pool_id]}) + if not pools: + return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"]) + + if len(pools) != 1: + return Error( + code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"] + ) + pool = pools[0] + return pool + + @classmethod + def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]: + pools = cls.search(query={"name": [name]}) + if not pools: + return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"]) + + if len(pools) != 1: + return Error( + code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"] + ) + pool = pools[0] + return pool + + @classmethod + def search_states(cls, *, states: Optional[List[PoolState]] = None) -> List["Pool"]: + query: QueryFilter = {} + if states: + query["state"] = states + return cls.search(query=query) + + def set_shutdown(self, now: bool) -> None: + if self.state in [PoolState.halt, PoolState.shutdown]: + return + + if now: + self.state = PoolState.halt + else: + self.state = PoolState.shutdown + + self.save() + + def shutdown(self) -> None: + """ shutdown allows nodes to finish current work then delete """ + from .nodes import Node + from .scalesets import Scaleset + + scalesets = Scaleset.search_by_pool(self.name) + nodes = Node.search(query={"pool_name": [self.name]}) + if not scalesets and not nodes: + logging.info("pool stopped, deleting: %s", self.name) + + self.state = PoolState.halt + self.delete() + return + + for scaleset in scalesets: + scaleset.set_shutdown(now=False) + + for node in nodes: + node.set_shutdown() + + self.save() + + def halt(self) -> None: + """ halt the pool immediately """ + + from .nodes import Node + from .scalesets import Scaleset + + scalesets = Scaleset.search_by_pool(self.name) + nodes = Node.search(query={"pool_name": [self.name]}) + if not scalesets and not nodes: + delete_queue(self.get_pool_queue(), StorageType.corpus) + logging.info("pool stopped, deleting: %s", self.name) + self.state = PoolState.halt + self.delete() + return + + for scaleset in scalesets: + scaleset.state = ScalesetState.halt + scaleset.save() + + for node in nodes: + node.set_halt() + + self.save() + + @classmethod + def key_fields(cls) -> Tuple[str, str]: + return ("name", "pool_id") + + def delete(self) -> None: + super().delete() + send_event(EventPoolDeleted(pool_name=self.name)) diff --git a/src/api-service/__app__/onefuzzlib/workers/scalesets.py b/src/api-service/__app__/onefuzzlib/workers/scalesets.py new file mode 100644 index 000000000..eeb9b3161 --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/workers/scalesets.py @@ -0,0 +1,601 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import datetime +import logging +from typing import Any, Dict, List, Optional, Tuple, Union +from uuid import UUID, uuid4 + +from onefuzztypes.enums import ErrorCode, NodeState, PoolState, ScalesetState +from onefuzztypes.events import ( + EventScalesetCreated, + EventScalesetDeleted, + EventScalesetFailed, +) +from onefuzztypes.models import Error +from onefuzztypes.models import Scaleset as BASE_SCALESET +from onefuzztypes.models import ScalesetNodeState +from onefuzztypes.primitives import PoolName, Region +from pydantic import BaseModel, Field + +from ..azure.auth import build_auth +from ..azure.image import get_os +from ..azure.network import Network +from ..azure.queue import ( + clear_queue, + create_queue, + delete_queue, + queue_object, + remove_first_message, +) +from ..azure.storage import StorageType +from ..azure.vmss import ( + UnableToUpdate, + create_vmss, + delete_vmss, + delete_vmss_nodes, + get_vmss, + get_vmss_size, + list_instance_ids, + reimage_vmss_nodes, + resize_vmss, + update_extensions, +) +from ..events import send_event +from ..extension import fuzz_extensions +from ..orm import MappingIntStrAny, ORMMixin, QueryFilter +from .nodes import Node + +NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) +NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) + +# Future work: +# +# Enabling autoscaling for the scalesets based on the pool work queues. +# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics + + +class Scaleset(BASE_SCALESET, ORMMixin): + def save_exclude(self) -> Optional[MappingIntStrAny]: + return {"nodes": ...} + + def telemetry_include(self) -> Optional[MappingIntStrAny]: + return { + "scaleset_id": ..., + "os": ..., + "vm_sku": ..., + "size": ..., + "spot_instances": ..., + } + + @classmethod + def create( + cls, + *, + pool_name: PoolName, + vm_sku: str, + image: str, + region: Region, + size: int, + spot_instances: bool, + tags: Dict[str, str], + client_id: Optional[UUID] = None, + client_object_id: Optional[UUID] = None, + ) -> "Scaleset": + entry = cls( + pool_name=pool_name, + vm_sku=vm_sku, + image=image, + region=region, + size=size, + spot_instances=spot_instances, + auth=build_auth(), + client_id=client_id, + client_object_id=client_object_id, + tags=tags, + ) + entry.save() + send_event( + EventScalesetCreated( + scaleset_id=entry.scaleset_id, + pool_name=entry.pool_name, + vm_sku=vm_sku, + image=image, + region=region, + size=size, + ) + ) + return entry + + @classmethod + def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]: + return cls.search(query={"pool_name": [pool_name]}) + + @classmethod + def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]: + scalesets = cls.search(query={"scaleset_id": [scaleset_id]}) + if not scalesets: + return Error( + code=ErrorCode.INVALID_REQUEST, errors=["unable to find scaleset"] + ) + + if len(scalesets) != 1: + return Error( + code=ErrorCode.INVALID_REQUEST, errors=["error identifying scaleset"] + ) + scaleset = scalesets[0] + return scaleset + + @classmethod + def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]: + return cls.search(query={"client_object_id": [object_id]}) + + def set_failed(self, error: Error) -> None: + if self.error is not None: + return + + self.error = error + self.state = ScalesetState.creation_failed + self.save() + + send_event( + EventScalesetFailed( + scaleset_id=self.scaleset_id, pool_name=self.pool_name, error=self.error + ) + ) + + def 
+    def init(self) -> None:
+        from .pools import Pool
+
+        logging.info("scaleset init: %s", self.scaleset_id)
+
+        ScalesetShrinkQueue(self.scaleset_id).create()
+
+        # Handle the race condition between a pool being deleted and a
+        # scaleset being added to the pool.
+        pool = Pool.get_by_name(self.pool_name)
+        if isinstance(pool, Error):
+            self.set_failed(pool)
+            return
+
+        if pool.state == PoolState.init:
+            logging.info(
+                "scaleset waiting for pool: %s - %s", self.pool_name, self.scaleset_id
+            )
+        elif pool.state == PoolState.running:
+            image_os = get_os(self.region, self.image)
+            if isinstance(image_os, Error):
+                self.set_failed(image_os)
+                return
+
+            elif image_os != pool.os:
+                error = Error(
+                    code=ErrorCode.INVALID_REQUEST,
+                    errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
+                )
+                self.set_failed(error)
+                return
+            else:
+                self.state = ScalesetState.setup
+        else:
+            self.state = ScalesetState.setup
+
+        self.save()
+
+    def setup(self) -> None:
+        from .pools import Pool
+
+        # TODO: How do we pass in SSH configs for Windows?  Previously, this
+        # was done as part of the generated per-task setup script.
+        logging.info("scaleset setup: %s", self.scaleset_id)
+
+        network = Network(self.region)
+        network_id = network.get_id()
+        if not network_id:
+            logging.info("creating network: %s", self.region)
+            result = network.create()
+            if isinstance(result, Error):
+                self.set_failed(result)
+                return
+            self.save()
+            return
+
+        if self.auth is None:
+            error = Error(
+                code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
+            )
+            self.set_failed(error)
+            return
+
+        vmss = get_vmss(self.scaleset_id)
+        if vmss is None:
+            pool = Pool.get_by_name(self.pool_name)
+            if isinstance(pool, Error):
+                self.set_failed(pool)
+                return
+
+            logging.info("creating scaleset: %s", self.scaleset_id)
+            extensions = fuzz_extensions(pool, self)
+            result = create_vmss(
+                self.region,
+                self.scaleset_id,
+                self.vm_sku,
+                self.size,
+                self.image,
+                network_id,
+                self.spot_instances,
+                extensions,
+                self.auth.password,
+                self.auth.public_key,
+                self.tags,
+            )
+            if isinstance(result, Error):
+                self.set_failed(result)
+                return
+            else:
+                logging.info("creating scaleset: %s", self.scaleset_id)
+        elif vmss.provisioning_state == "Creating":
+            logging.info("Waiting on scaleset creation: %s", self.scaleset_id)
+            self.try_set_identity(vmss)
+        else:
+            logging.info("scaleset running: %s", self.scaleset_id)
+            identity_result = self.try_set_identity(vmss)
+            if identity_result:
+                self.set_failed(identity_result)
+                return
+            else:
+                self.state = ScalesetState.running
+        self.save()
+
+    def try_set_identity(self, vmss: Any) -> Optional[Error]:
+        def get_error() -> Error:
+            return Error(
+                code=ErrorCode.VM_CREATE_FAILED,
+                errors=[
+                    "The scaleset is expected to have exactly 1 user assigned identity"
+                ],
+            )
+
+        if self.client_object_id:
+            return None
+
+        # refuse to proceed unless exactly one user-assigned identity exists
+        if (
+            not vmss.identity
+            or not vmss.identity.user_assigned_identities
+            or (len(vmss.identity.user_assigned_identities) != 1)
+        ):
+            return get_error()
+
+        user_assigned_identities = list(vmss.identity.user_assigned_identities.values())
+
+        if user_assigned_identities[0].principal_id:
+            self.client_object_id = user_assigned_identities[0].principal_id
+            return None
+        else:
+            return get_error()
+
+    # result = 'did I modify the scaleset in azure'
+    def cleanup_nodes(self) -> bool:
+        if self.state == ScalesetState.halt:
+            logging.info("halting scaleset: %s", self.scaleset_id)
+            self.halt()
+            return True
+
+        Node.reimage_long_lived_nodes(self.scaleset_id)
+
+        to_reimage = []
+        to_delete = []
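+
+        # Reconcile the Node table against the instance list Azure reports
+        # for the VMSS: stale table entries are dropped, nodes ready for
+        # reset are queued for deletion or reimage, and dead nodes are
+        # halted and reimaged.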
+        # ground truth of existing nodes
+        azure_nodes = list_instance_ids(self.scaleset_id)
+
+        nodes = Node.search_states(scaleset_id=self.scaleset_id)
+
+        # Delete nodes that exist in the table but not in the scaleset,
+        # which can happen after an unexplained failure.
+        for node in nodes:
+            if node.machine_id not in azure_nodes:
+                logging.info(
+                    "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
+                )
+                node.delete()
+
+        existing_nodes = [x for x in nodes if x.machine_id in azure_nodes]
+        nodes_to_reset = [
+            x for x in existing_nodes if x.state in NodeState.ready_for_reset()
+        ]
+
+        for node in nodes_to_reset:
+            if node.delete_requested:
+                to_delete.append(node)
+            else:
+                if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
+                    node.set_halt()
+                    to_delete.append(node)
+                elif not node.reimage_queued:
+                    # only add nodes that are not already set to reschedule
+                    to_reimage.append(node)
+
+        dead_nodes = Node.get_dead_nodes(self.scaleset_id, NODE_EXPIRATION_TIME)
+        for node in dead_nodes:
+            node.set_halt()
+            to_reimage.append(node)
+
+        # Perform operations until they fail due to scaleset getting locked
+        try:
+            if to_delete:
+                logging.info(
+                    "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
+                )
+                self.delete_nodes(to_delete)
+                for node in to_delete:
+                    node.set_halt()
+
+            if to_reimage:
+                self.reimage_nodes(to_reimage)
+        except UnableToUpdate:
+            logging.info("scaleset update already in progress: %s", self.scaleset_id)
+
+        return bool(to_reimage) or bool(to_delete)
+
+    def _resize_equal(self) -> None:
+        # NOTE: this is the only place we reset to the 'running' state.
+        # This ensures that our idea of scaleset size agrees with Azure
+        node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
+        if node_count == self.size:
+            logging.info("resize finished: %s", self.scaleset_id)
+            self.state = ScalesetState.running
+            self.save()
+            return
+        else:
+            logging.info(
+                "resize is finished, waiting for nodes to check in: "
+                "%s (%d of %d nodes checked in)",
+                self.scaleset_id,
+                node_count,
+                self.size,
+            )
+            return
+
+    def _resize_grow(self) -> None:
+        try:
+            resize_vmss(self.scaleset_id, self.size)
+        except UnableToUpdate:
+            logging.info("scaleset is mid-operation already")
+        return
+
+    def _resize_shrink(self, to_remove: int) -> None:
+        queue = ScalesetShrinkQueue(self.scaleset_id)
+        for _ in range(to_remove):
+            queue.add_entry()
+
+    def resize(self) -> None:
+        # no longer needs to resize
+        if self.state != ScalesetState.resize:
+            return
+
+        logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)
+
+        # reset the node delete queue
+        ScalesetShrinkQueue(self.scaleset_id).clear()
+
+        # just in case, always ensure size is within max capacity
+        self.size = min(self.size, self.max_size())
+
+        # Treat Azure's knowledge of the size of the scaleset as ground truth
+        size = get_vmss_size(self.scaleset_id)
+        if size is None:
+            logging.info("scaleset is unavailable: %s", self.scaleset_id)
+            return
+
+        if size == self.size:
+            self._resize_equal()
+        elif self.size > size:
+            self._resize_grow()
+        else:
+            self._resize_shrink(size - self.size)
+
+    def delete_nodes(self, nodes: List[Node]) -> None:
+        if not nodes:
+            logging.debug("no nodes to delete")
+            return
+
+        if self.state == ScalesetState.halt:
+            logging.debug(
+                "scaleset is set to halt, nodes will be deleted with it: %s",
+                self.scaleset_id,
+            )
+            return
+
+        machine_ids = []
+        for node in nodes:
+            if node.debug_keep_node:
+                logging.warning(
+                    "delete manually overridden %s:%s",
+                    self.scaleset_id,
+                    node.machine_id,
+                )
+            else:
+                machine_ids.append(node.machine_id)
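+
+        # Only machine IDs not flagged with debug_keep_node are handed to the
+        # VMSS delete call below.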
logging.info("deleting %s:%s", self.scaleset_id, machine_ids) + delete_vmss_nodes(self.scaleset_id, machine_ids) + + def reimage_nodes(self, nodes: List[Node]) -> None: + if not nodes: + logging.debug("no nodes to reimage") + return + + if self.state == ScalesetState.shutdown: + self.delete_nodes(nodes) + return + + if self.state == ScalesetState.halt: + logging.debug("scaleset delete will delete node: %s", self.scaleset_id) + return + + machine_ids = [] + for node in nodes: + if node.debug_keep_node: + logging.warning( + "reimage manually overridden %s:%s", + self.scaleset_id, + node.machine_id, + ) + else: + machine_ids.append(node.machine_id) + + result = reimage_vmss_nodes(self.scaleset_id, machine_ids) + if isinstance(result, Error): + raise Exception( + "unable to reimage nodes: %s:%s - %s" + % (self.scaleset_id, machine_ids, result) + ) + for node in nodes: + node.reimage_queued = True + node.save() + + def set_shutdown(self, now: bool) -> None: + if self.state in [ScalesetState.halt, ScalesetState.shutdown]: + return + + if now: + self.state = ScalesetState.halt + else: + self.state = ScalesetState.shutdown + + self.save() + + def shutdown(self) -> None: + size = get_vmss_size(self.scaleset_id) + logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size) + nodes = Node.search_states(scaleset_id=self.scaleset_id) + for node in nodes: + node.set_shutdown() + if size is None or size == 0: + self.halt() + + def halt(self) -> None: + ScalesetShrinkQueue(self.scaleset_id).delete() + + for node in Node.search_states(scaleset_id=self.scaleset_id): + logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id) + node.delete() + + logging.info("scaleset delete starting: %s", self.scaleset_id) + if delete_vmss(self.scaleset_id): + logging.info("scaleset deleted: %s", self.scaleset_id) + self.delete() + else: + self.save() + + @classmethod + def scaleset_max_size(cls, image: str) -> int: + # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/ + # virtual-machine-scale-sets-placement-groups#checklist-for-using-large-scale-sets + if image.startswith("/"): + return 600 + else: + return 1000 + + def max_size(self) -> int: + return Scaleset.scaleset_max_size(self.image) + + @classmethod + def search_states( + cls, *, states: Optional[List[ScalesetState]] = None + ) -> List["Scaleset"]: + query: QueryFilter = {} + if states: + query["state"] = states + return cls.search(query=query) + + def update_nodes(self) -> None: + # Be in at-least 'setup' before checking for the list of VMs + if self.state == ScalesetState.init: + return + + nodes = Node.search_states(scaleset_id=self.scaleset_id) + azure_nodes = list_instance_ids(self.scaleset_id) + + self.nodes = [] + + for (machine_id, instance_id) in azure_nodes.items(): + node_state: Optional[ScalesetNodeState] = None + for node in nodes: + if node.machine_id == machine_id: + node_state = ScalesetNodeState( + machine_id=machine_id, + instance_id=instance_id, + state=node.state, + ) + break + if not node_state: + node_state = ScalesetNodeState( + machine_id=machine_id, + instance_id=instance_id, + ) + self.nodes.append(node_state) + + def update_configs(self) -> None: + from .pools import Pool + + if not self.needs_config_update: + logging.debug("config update not needed: %s", self.scaleset_id) + + logging.info("updating scaleset configs: %s", self.scaleset_id) + + pool = Pool.get_by_name(self.pool_name) + if isinstance(pool, Error): + logging.error( + "unable to find pool during config update: %s - %s", + 
+    def update_configs(self) -> None:
+        from .pools import Pool
+
+        if not self.needs_config_update:
+            logging.debug("config update not needed: %s", self.scaleset_id)
+            return
+
+        logging.info("updating scaleset configs: %s", self.scaleset_id)
+
+        pool = Pool.get_by_name(self.pool_name)
+        if isinstance(pool, Error):
+            logging.error(
+                "unable to find pool during config update: %s - %s",
+                self.scaleset_id,
+                pool,
+            )
+            self.set_failed(pool)
+            return
+
+        extensions = fuzz_extensions(pool, self)
+        try:
+            update_extensions(self.scaleset_id, extensions)
+            self.needs_config_update = False
+            self.save()
+        except UnableToUpdate:
+            logging.debug(
+                "unable to update configs, update already in progress: %s",
+                self.scaleset_id,
+            )
+
+    @classmethod
+    def key_fields(cls) -> Tuple[str, str]:
+        return ("pool_name", "scaleset_id")
+
+    def delete(self) -> None:
+        super().delete()
+        send_event(
+            EventScalesetDeleted(scaleset_id=self.scaleset_id, pool_name=self.pool_name)
+        )
+
+
+class ShrinkEntry(BaseModel):
+    shrink_id: UUID = Field(default_factory=uuid4)
+
+
+class ScalesetShrinkQueue:
+    def __init__(self, scaleset_id: UUID):
+        self.scaleset_id = scaleset_id
+
+    def queue_name(self) -> str:
+        return "to-shrink-%s" % self.scaleset_id.hex
+
+    def clear(self) -> None:
+        clear_queue(self.queue_name(), StorageType.config)
+
+    def create(self) -> None:
+        create_queue(self.queue_name(), StorageType.config)
+
+    def delete(self) -> None:
+        delete_queue(self.queue_name(), StorageType.config)
+
+    def add_entry(self) -> None:
+        queue_object(self.queue_name(), ShrinkEntry(), StorageType.config)
+
+    def should_shrink(self) -> bool:
+        return remove_first_message(self.queue_name(), StorageType.config)
diff --git a/src/api-service/__app__/pool/__init__.py b/src/api-service/__app__/pool/__init__.py
index 4160aa7eb..350cb8172 100644
--- a/src/api-service/__app__/pool/__init__.py
+++ b/src/api-service/__app__/pool/__init__.py
@@ -23,8 +23,8 @@ from ..onefuzzlib.azure.storage import StorageType
 from ..onefuzzlib.azure.vmss import list_available_skus
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Pool
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.pools import Pool


 def set_config(pool: Pool) -> Pool:
diff --git a/src/api-service/__app__/proxy/__init__.py b/src/api-service/__app__/proxy/__init__.py
index d4bdea682..579647c94 100644
--- a/src/api-service/__app__/proxy/__init__.py
+++ b/src/api-service/__app__/proxy/__init__.py
@@ -13,10 +13,10 @@ from onefuzztypes.responses import BoolResult, ProxyGetResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Scaleset
 from ..onefuzzlib.proxy import Proxy
 from ..onefuzzlib.proxy_forward import ProxyForward
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.scalesets import Scaleset


 def get_result(proxy_forward: ProxyForward, proxy: Optional[Proxy]) -> ProxyGetResult:
diff --git a/src/api-service/__app__/queue_node_heartbeat/__init__.py b/src/api-service/__app__/queue_node_heartbeat/__init__.py
index e43ccf37e..7733bb2d2 100644
--- a/src/api-service/__app__/queue_node_heartbeat/__init__.py
+++ b/src/api-service/__app__/queue_node_heartbeat/__init__.py
@@ -12,7 +12,7 @@ from onefuzztypes.models import NodeHeartbeatEntry
 from pydantic import ValidationError

 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node
+from ..onefuzzlib.workers.nodes import Node


 def main(msg: func.QueueMessage, dashboard: func.Out[str]) -> None:
diff --git a/src/api-service/__app__/scaleset/__init__.py b/src/api-service/__app__/scaleset/__init__.py
index 8f5e955a8..b62175ed9 100644
--- a/src/api-service/__app__/scaleset/__init__.py
+++ b/src/api-service/__app__/scaleset/__init__.py
@@ -18,8 +18,9 @@ from ..onefuzzlib.azure.creds import get_base_region, get_regions
 from ..onefuzzlib.azure.vmss import list_available_skus
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Pool, Scaleset
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.pools import Pool
+from ..onefuzzlib.workers.scalesets import Scaleset


 def get(req: func.HttpRequest) -> func.HttpResponse:
diff --git a/src/api-service/__app__/tasks/__init__.py b/src/api-service/__app__/tasks/__init__.py
index 685df7781..1fc276073 100644
--- a/src/api-service/__app__/tasks/__init__.py
+++ b/src/api-service/__app__/tasks/__init__.py
@@ -12,12 +12,12 @@ from onefuzztypes.responses import BoolResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
 from ..onefuzzlib.jobs import Job
-from ..onefuzzlib.pools import NodeTasks
 from ..onefuzzlib.request import not_ok, ok, parse_request
 from ..onefuzzlib.task_event import TaskEvent
 from ..onefuzzlib.tasks.config import TaskConfigError, check_config
 from ..onefuzzlib.tasks.main import Task
 from ..onefuzzlib.user_credentials import parse_jwt_token
+from ..onefuzzlib.workers.nodes import NodeTasks


 def post(req: func.HttpRequest) -> func.HttpResponse:
diff --git a/src/api-service/__app__/timer_daily/__init__.py b/src/api-service/__app__/timer_daily/__init__.py
index 57eed6631..ca6fe1d27 100644
--- a/src/api-service/__app__/timer_daily/__init__.py
+++ b/src/api-service/__app__/timer_daily/__init__.py
@@ -9,9 +9,9 @@ import azure.functions as func
 from onefuzztypes.enums import VmState

 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Scaleset
 from ..onefuzzlib.proxy import Proxy
 from ..onefuzzlib.webhooks import WebhookMessageLog
+from ..onefuzzlib.workers.scalesets import Scaleset


 def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None:  # noqa: F841
diff --git a/src/api-service/__app__/timer_workers/__init__.py b/src/api-service/__app__/timer_workers/__init__.py
index 81e49a567..09bac76c0 100644
--- a/src/api-service/__app__/timer_workers/__init__.py
+++ b/src/api-service/__app__/timer_workers/__init__.py
@@ -11,7 +11,9 @@ from onefuzztypes.enums import NodeState, PoolState
 from ..onefuzzlib.autoscale import autoscale_pool
 from ..onefuzzlib.events import get_events
 from ..onefuzzlib.orm import process_state_updates
-from ..onefuzzlib.pools import Node, Pool, Scaleset
+from ..onefuzzlib.workers.nodes import Node
+from ..onefuzzlib.workers.pools import Pool
+from ..onefuzzlib.workers.scalesets import Scaleset


 def process_scaleset(scaleset: Scaleset) -> None:
diff --git a/src/api-service/tests/test_autoscale.py b/src/api-service/tests/test_autoscale.py
index 506b8fb9c..b56fc6164 100644
--- a/src/api-service/tests/test_autoscale.py
+++ b/src/api-service/tests/test_autoscale.py
@@ -12,8 +12,8 @@ from onefuzztypes.models import TaskConfig, TaskContainers, TaskDetails, TaskPoo
 from onefuzztypes.primitives import Container, PoolName

 from __app__.onefuzzlib.autoscale import autoscale_pool, get_vm_count
-from __app__.onefuzzlib.pools import Pool
 from __app__.onefuzzlib.tasks.main import Task
+from __app__.onefuzzlib.workers.pools import Pool


 class TestAutoscale(unittest.TestCase):