diff --git a/src/api-service/__app__/onefuzzlib/workers/nodes.py b/src/api-service/__app__/onefuzzlib/workers/nodes.py index f18e6a494..0a254000e 100644 --- a/src/api-service/__app__/onefuzzlib/workers/nodes.py +++ b/src/api-service/__app__/onefuzzlib/workers/nodes.py @@ -32,6 +32,7 @@ from ..azure.vmss import get_instance_id from ..events import send_event from ..orm import MappingIntStrAny, ORMMixin, QueryFilter from ..versions import is_minimum_version +from .shrink_queue import ShrinkQueue NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) @@ -242,9 +243,7 @@ class Node(BASE_NODE, ORMMixin): entry.delete() def could_shrink_scaleset(self) -> bool: - from .scalesets import ScalesetShrinkQueue - - if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink(): + if self.scaleset_id and ShrinkQueue(self.scaleset_id).should_shrink(): return True return False diff --git a/src/api-service/__app__/onefuzzlib/workers/scalesets.py b/src/api-service/__app__/onefuzzlib/workers/scalesets.py index c93c65b72..327c9302e 100644 --- a/src/api-service/__app__/onefuzzlib/workers/scalesets.py +++ b/src/api-service/__app__/onefuzzlib/workers/scalesets.py @@ -6,7 +6,7 @@ import datetime import logging from typing import Any, Dict, List, Optional, Tuple, Union -from uuid import UUID, uuid4 +from uuid import UUID from onefuzztypes.enums import ErrorCode, NodeState, PoolState, ScalesetState from onefuzztypes.events import ( @@ -19,20 +19,11 @@ from onefuzztypes.models import Error from onefuzztypes.models import Scaleset as BASE_SCALESET from onefuzztypes.models import ScalesetNodeState from onefuzztypes.primitives import PoolName, Region -from pydantic import BaseModel, Field from ..__version__ import __version__ from ..azure.auth import build_auth from ..azure.image import get_os from ..azure.network import Network -from ..azure.queue import ( - clear_queue, - create_queue, - delete_queue, - queue_object, - remove_first_message, -) -from ..azure.storage import StorageType from ..azure.vmss import ( UnableToUpdate, create_vmss, @@ -49,6 +40,7 @@ from ..events import send_event from ..extension import fuzz_extensions from ..orm import MappingIntStrAny, ORMMixin, QueryFilter from .nodes import Node +from .shrink_queue import ShrinkQueue NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1) NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7) @@ -155,7 +147,7 @@ class Scaleset(BASE_SCALESET, ORMMixin): logging.info(SCALESET_LOG_PREFIX + "init. scaleset_id:%s", self.scaleset_id) - ScalesetShrinkQueue(self.scaleset_id).create() + ShrinkQueue(self.scaleset_id).create() # Handle the race condition between a pool being deleted and a # scaleset being added to the pool. @@ -369,7 +361,7 @@ class Scaleset(BASE_SCALESET, ORMMixin): if node.delete_requested: to_delete.append(node) else: - if ScalesetShrinkQueue(self.scaleset_id).should_shrink(): + if ShrinkQueue(self.scaleset_id).should_shrink(): node.set_halt() to_delete.append(node) else: @@ -445,9 +437,8 @@ class Scaleset(BASE_SCALESET, ORMMixin): self.scaleset_id, to_remove, ) - queue = ScalesetShrinkQueue(self.scaleset_id) - for _ in range(to_remove): - queue.add_entry() + queue = ShrinkQueue(self.scaleset_id) + queue.set_size(to_remove) nodes = Node.search_states(scaleset_id=self.scaleset_id) for node in nodes: @@ -493,7 +484,7 @@ class Scaleset(BASE_SCALESET, ORMMixin): ) # reset the node delete queue - ScalesetShrinkQueue(self.scaleset_id).clear() + ShrinkQueue(self.scaleset_id).clear() # just in case, always ensure size is within max capacity self.size = min(self.size, self.max_size()) @@ -654,7 +645,7 @@ class Scaleset(BASE_SCALESET, ORMMixin): self.halt() def halt(self) -> None: - ScalesetShrinkQueue(self.scaleset_id).delete() + ShrinkQueue(self.scaleset_id).delete() for node in Node.search_states(scaleset_id=self.scaleset_id): logging.info( @@ -794,30 +785,3 @@ class Scaleset(BASE_SCALESET, ORMMixin): scaleset_id=self.scaleset_id, pool_name=self.pool_name, state=self.state ) ) - - -class ShrinkEntry(BaseModel): - shrink_id: UUID = Field(default_factory=uuid4) - - -class ScalesetShrinkQueue: - def __init__(self, scaleset_id: UUID): - self.scaleset_id = scaleset_id - - def queue_name(self) -> str: - return "to-shrink-%s" % self.scaleset_id.hex - - def clear(self) -> None: - clear_queue(self.queue_name(), StorageType.config) - - def create(self) -> None: - create_queue(self.queue_name(), StorageType.config) - - def delete(self) -> None: - delete_queue(self.queue_name(), StorageType.config) - - def add_entry(self) -> None: - queue_object(self.queue_name(), ShrinkEntry(), StorageType.config) - - def should_shrink(self) -> bool: - return remove_first_message(self.queue_name(), StorageType.config) diff --git a/src/api-service/__app__/onefuzzlib/workers/shrink_queue.py b/src/api-service/__app__/onefuzzlib/workers/shrink_queue.py new file mode 100644 index 000000000..28a7d8d75 --- /dev/null +++ b/src/api-service/__app__/onefuzzlib/workers/shrink_queue.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + +from ..azure.queue import ( + clear_queue, + create_queue, + delete_queue, + queue_object, + remove_first_message, +) +from ..azure.storage import StorageType + + +class ShrinkEntry(BaseModel): + shrink_id: UUID = Field(default_factory=uuid4) + + +class ShrinkQueue: + def __init__(self, base_id: UUID): + self.base_id = base_id + + def queue_name(self) -> str: + return "to-shrink-%s" % self.base_id.hex + + def clear(self) -> None: + clear_queue(self.queue_name(), StorageType.config) + + def create(self) -> None: + create_queue(self.queue_name(), StorageType.config) + + def delete(self) -> None: + delete_queue(self.queue_name(), StorageType.config) + + def add_entry(self) -> None: + queue_object(self.queue_name(), ShrinkEntry(), StorageType.config) + + def set_size(self, size: int) -> None: + self.clear() + for _ in range(size): + self.add_entry() + + def should_shrink(self) -> bool: + return remove_first_message(self.queue_name(), StorageType.config)