mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
refactor node state to fully put the agent in charge (#90)
This commit is contained in:
@ -88,6 +88,15 @@ def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceC
|
||||
return None
|
||||
|
||||
|
||||
def clear_queue(name: QueueNameType, *, account_id: str) -> None:
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
if queue:
|
||||
try:
|
||||
queue.clear_messages()
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def send_message(
|
||||
name: QueueNameType,
|
||||
message: bytes,
|
||||
@ -102,6 +111,19 @@ def send_message(
|
||||
pass
|
||||
|
||||
|
||||
def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
|
||||
create_queue(name, account_id=account_id)
|
||||
queue = get_queue(name, account_id=account_id)
|
||||
if queue:
|
||||
try:
|
||||
for message in queue.receive_messages():
|
||||
queue.delete_message(message)
|
||||
return True
|
||||
except ResourceNotFoundError:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
A = TypeVar("A", bound=BaseModel)
|
||||
|
||||
|
||||
|
@ -76,7 +76,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:
|
||||
|
||||
|
||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||
logging.info("get instance IDs for scaleset: %s", name)
|
||||
logging.debug("get instance IDs for scaleset: %s", name)
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = mgmt_client_factory(ComputeManagementClient)
|
||||
|
||||
|
@ -6,7 +6,7 @@
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from onefuzztypes.enums import (
|
||||
OS,
|
||||
@ -15,7 +15,6 @@ from onefuzztypes.enums import (
|
||||
NodeState,
|
||||
PoolState,
|
||||
ScalesetState,
|
||||
TaskState,
|
||||
)
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Node as BASE_NODE
|
||||
@ -32,14 +31,21 @@ from onefuzztypes.models import (
|
||||
WorkUnitSummary,
|
||||
)
|
||||
from onefuzztypes.primitives import PoolName, Region
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .__version__ import __version__
|
||||
from .azure.auth import build_auth
|
||||
from .azure.creds import get_fuzz_storage
|
||||
from .azure.creds import get_func_storage, get_fuzz_storage
|
||||
from .azure.image import get_os
|
||||
from .azure.network import Network
|
||||
from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
|
||||
from .azure.queue import (
|
||||
clear_queue,
|
||||
create_queue,
|
||||
delete_queue,
|
||||
peek_queue,
|
||||
queue_object,
|
||||
remove_first_message,
|
||||
)
|
||||
from .azure.vmss import (
|
||||
UnableToUpdate,
|
||||
create_vmss,
|
||||
@ -99,7 +105,7 @@ class Node(BASE_NODE, ORMMixin):
|
||||
# azure table query always return false when the column does not exist
|
||||
# We write the query this way to allow us to get the nodes where the
|
||||
# version is not defined as well as the nodes with a mismatched version
|
||||
version_query = "not (version ne '%s')" % __version__
|
||||
version_query = "not (version eq '%s')" % __version__
|
||||
return cls.search(query=query, raw_unchecked_filter=version_query)
|
||||
|
||||
@classmethod
|
||||
@ -154,16 +160,90 @@ class Node(BASE_NODE, ORMMixin):
|
||||
for node in nodes:
|
||||
if node.state not in NodeState.ready_for_reset():
|
||||
logging.info(
|
||||
"stopping task %s on machine_id:%s",
|
||||
task_id,
|
||||
"stopping machine_id:%s running task:%s",
|
||||
node.machine_id,
|
||||
task_id,
|
||||
)
|
||||
node.state = NodeState.done
|
||||
node.save()
|
||||
node.stop()
|
||||
|
||||
def mark_tasks_stopped_early(self) -> None:
|
||||
from .tasks.main import Task
|
||||
|
||||
for entry in NodeTasks.get_by_machine_id(self.machine_id):
|
||||
task = Task.get_by_task_id(entry.task_id)
|
||||
if isinstance(task, Task):
|
||||
task.mark_failed(
|
||||
Error(
|
||||
code=ErrorCode.TASK_FAILED,
|
||||
errors=["node reimaged during task execution"],
|
||||
)
|
||||
)
|
||||
entry.delete()
|
||||
|
||||
def could_shrink_scaleset(self) -> bool:
|
||||
if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink():
|
||||
return True
|
||||
return False
|
||||
|
||||
def can_process_new_work(self) -> bool:
|
||||
if self.is_outdated():
|
||||
logging.info(
|
||||
"can_schedule old version machine_id:%s version:%s",
|
||||
self.machine_id,
|
||||
self.version,
|
||||
)
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
if self.delete_requested or self.reimage_requested:
|
||||
logging.info(
|
||||
"can_schedule should be recycled. machine_id:%s", self.machine_id
|
||||
)
|
||||
self.stop()
|
||||
return False
|
||||
|
||||
if self.could_shrink_scaleset():
|
||||
self.set_halt()
|
||||
logging.info("node scheduled to shrink. machine_id:%s", self.machine_id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def is_outdated(self) -> bool:
|
||||
return self.version != __version__
|
||||
|
||||
def send_message(self, message: NodeCommand) -> None:
|
||||
stop_message = NodeMessage(
|
||||
agent_id=self.machine_id,
|
||||
message=message,
|
||||
)
|
||||
stop_message.save()
|
||||
|
||||
def to_reimage(self, done: bool = False) -> None:
|
||||
if done:
|
||||
if self.state not in NodeState.ready_for_reset():
|
||||
self.state = NodeState.done
|
||||
|
||||
if not self.reimage_requested and not self.delete_requested:
|
||||
logging.info("setting reimage_requested: %s", self.machine_id)
|
||||
self.reimage_requested = True
|
||||
self.save()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.to_reimage()
|
||||
self.send_message(NodeCommand(stop=StopNodeCommand()))
|
||||
|
||||
def set_shutdown(self) -> None:
|
||||
# don't give out more work to the node, but let it finish existing work
|
||||
logging.info("setting delete_requested: %s", self.machine_id)
|
||||
self.delete_requested = True
|
||||
self.save()
|
||||
|
||||
def set_halt(self) -> None:
|
||||
""" Tell the node to stop everything. """
|
||||
self.set_shutdown()
|
||||
self.stop()
|
||||
|
||||
|
||||
class NodeTasks(BASE_NODE_TASK, ORMMixin):
|
||||
@classmethod
|
||||
@ -384,8 +464,7 @@ class Pool(BASE_POOL, ORMMixin):
|
||||
scaleset.save()
|
||||
|
||||
for node in nodes:
|
||||
node.state = NodeState.shutdown
|
||||
node.save()
|
||||
node.set_shutdown()
|
||||
|
||||
self.save()
|
||||
|
||||
@ -405,13 +484,7 @@ class Pool(BASE_POOL, ORMMixin):
|
||||
scaleset.save()
|
||||
|
||||
for node in nodes:
|
||||
logging.info(
|
||||
"deleting node from pool: %s (%s) - machine_id:%s",
|
||||
self.pool_id,
|
||||
self.name,
|
||||
node.machine_id,
|
||||
)
|
||||
node.delete()
|
||||
node.set_halt()
|
||||
|
||||
self.save()
|
||||
|
||||
@ -496,6 +569,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
def init(self) -> None:
|
||||
logging.info("scaleset init: %s", self.scaleset_id)
|
||||
|
||||
ScalesetShrinkQueue(self.scaleset_id).create()
|
||||
|
||||
# Handle the race condition between a pool being deleted and a
|
||||
# scaleset being added to the pool.
|
||||
pool = Pool.get_by_name(self.pool_name)
|
||||
@ -596,49 +671,62 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
# result = 'did I modify the scaleset in azure'
|
||||
def cleanup_nodes(self) -> bool:
|
||||
if self.state == ScalesetState.halt:
|
||||
logging.info("halting scaleset: %s", self.scaleset_id)
|
||||
self.halt()
|
||||
return True
|
||||
|
||||
to_reimage = []
|
||||
to_delete = []
|
||||
|
||||
outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
|
||||
for node in outdated:
|
||||
logging.info(
|
||||
"node is outdated: %s - node_version:%s api_version:%s",
|
||||
node.machine_id,
|
||||
node.version,
|
||||
__version__,
|
||||
)
|
||||
if node.version == "1.0.0":
|
||||
node.state = NodeState.done
|
||||
to_reimage.append(node)
|
||||
else:
|
||||
node.to_reimage()
|
||||
|
||||
nodes = Node.search_states(
|
||||
scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
|
||||
)
|
||||
|
||||
outdated = Node.search_outdated(
|
||||
scaleset_id=self.scaleset_id,
|
||||
states=[NodeState.free],
|
||||
)
|
||||
|
||||
if not (nodes or outdated):
|
||||
logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
|
||||
if not outdated and not nodes:
|
||||
logging.info("no nodes need updating: %s", self.scaleset_id)
|
||||
return False
|
||||
|
||||
to_delete = []
|
||||
to_reimage = []
|
||||
|
||||
for node in outdated:
|
||||
if node.version == "1.0.0":
|
||||
to_reimage.append(node)
|
||||
else:
|
||||
stop_message = NodeMessage(
|
||||
agent_id=node.machine_id,
|
||||
message=NodeCommand(stop=StopNodeCommand()),
|
||||
)
|
||||
stop_message.save()
|
||||
# ground truth of existing nodes
|
||||
azure_nodes = list_instance_ids(self.scaleset_id)
|
||||
|
||||
for node in nodes:
|
||||
# delete nodes that are not waiting on the scaleset GC
|
||||
if not node.scaleset_node_exists():
|
||||
if node.machine_id not in azure_nodes:
|
||||
logging.info(
|
||||
"no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
|
||||
)
|
||||
node.delete()
|
||||
elif node.state in [NodeState.shutdown, NodeState.halt]:
|
||||
elif node.delete_requested:
|
||||
to_delete.append(node)
|
||||
else:
|
||||
to_reimage.append(node)
|
||||
if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
|
||||
node.set_halt()
|
||||
to_delete.append(node)
|
||||
else:
|
||||
to_reimage.append(node)
|
||||
|
||||
# Perform operations until they fail due to scaleset getting locked
|
||||
try:
|
||||
if to_delete:
|
||||
logging.info(
|
||||
"deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
|
||||
)
|
||||
self.delete_nodes(to_delete)
|
||||
for node in to_delete:
|
||||
node.set_halt()
|
||||
node.state = NodeState.halt
|
||||
node.save()
|
||||
|
||||
@ -646,69 +734,65 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
self.reimage_nodes(to_reimage)
|
||||
except UnableToUpdate:
|
||||
logging.info("scaleset update already in progress: %s", self.scaleset_id)
|
||||
|
||||
return True
|
||||
|
||||
def resize(self) -> None:
|
||||
logging.info(
|
||||
"scaleset resize: %s - current: %s new: %s",
|
||||
self.scaleset_id,
|
||||
self.size,
|
||||
self.new_size,
|
||||
)
|
||||
|
||||
# no work needed to resize
|
||||
if self.new_size is None:
|
||||
def _resize_equal(self) -> None:
|
||||
# NOTE: this is the only place we reset to the 'running' state.
|
||||
# This ensures that our idea of scaleset size agrees with Azure
|
||||
node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
|
||||
if node_count == self.size:
|
||||
logging.info("resize finished: %s", self.scaleset_id)
|
||||
self.state = ScalesetState.running
|
||||
self.save()
|
||||
return
|
||||
else:
|
||||
logging.info(
|
||||
"resize is finished, waiting for nodes to check in: "
|
||||
"%s (%d of %d nodes checked in)",
|
||||
self.scaleset_id,
|
||||
node_count,
|
||||
self.size,
|
||||
)
|
||||
return
|
||||
|
||||
def _resize_grow(self) -> None:
|
||||
try:
|
||||
resize_vmss(self.scaleset_id, self.size)
|
||||
except UnableToUpdate:
|
||||
logging.info("scaleset is mid-operation already")
|
||||
return
|
||||
|
||||
def _resize_shrink(self, to_remove: int) -> None:
|
||||
queue = ScalesetShrinkQueue(self.scaleset_id)
|
||||
for _ in range(to_remove):
|
||||
queue.add_entry()
|
||||
|
||||
def resize(self) -> None:
|
||||
# no longer needing to resize
|
||||
if self.state != ScalesetState.resize:
|
||||
return
|
||||
|
||||
logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)
|
||||
|
||||
# reset the node delete queue
|
||||
ScalesetShrinkQueue(self.scaleset_id).clear()
|
||||
|
||||
# just in case, always ensure size is within max capacity
|
||||
self.new_size = min(self.new_size, self.max_size())
|
||||
self.size = min(self.size, self.max_size())
|
||||
|
||||
# Treat Azure knowledge of the size of the scaleset as "ground truth"
|
||||
size = get_vmss_size(self.scaleset_id)
|
||||
if size is None:
|
||||
logging.info("scaleset is unavailable. Re-queuing")
|
||||
self.save()
|
||||
logging.info("scaleset is unavailable: %s", self.scaleset_id)
|
||||
return
|
||||
|
||||
if size == self.new_size:
|
||||
# NOTE: this is the only place we reset to the 'running' state.
|
||||
# This ensures that our idea of scaleset size agrees with Azure
|
||||
node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
|
||||
if node_count == self.size:
|
||||
logging.info("resize finished: %s", self.scaleset_id)
|
||||
self.new_size = None
|
||||
self.state = ScalesetState.running
|
||||
else:
|
||||
logging.info(
|
||||
"resize is finished, waiting for nodes to check in: "
|
||||
"%s (%d of %d nodes checked in)",
|
||||
self.scaleset_id,
|
||||
node_count,
|
||||
self.size,
|
||||
)
|
||||
# When adding capacity, call the resize API directly
|
||||
elif self.new_size > self.size:
|
||||
try:
|
||||
resize_vmss(self.scaleset_id, self.new_size)
|
||||
except UnableToUpdate:
|
||||
logging.info("scaleset is mid-operation already")
|
||||
# Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_node
|
||||
# to pick up that the scaleset is too big upon task completion
|
||||
if size == self.size:
|
||||
self._resize_equal()
|
||||
elif self.size > size:
|
||||
self._resize_grow()
|
||||
else:
|
||||
nodes = Node.search_states(
|
||||
scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
|
||||
)
|
||||
for node in nodes:
|
||||
if size > self.new_size:
|
||||
node.state = NodeState.halt
|
||||
node.save()
|
||||
size -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
self.save()
|
||||
self._resize_shrink(size - self.size)
|
||||
|
||||
def delete_nodes(self, nodes: List[Node]) -> None:
|
||||
if not nodes:
|
||||
@ -723,31 +807,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
|
||||
logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
|
||||
delete_vmss_nodes(self.scaleset_id, machine_ids)
|
||||
self.size -= len(machine_ids)
|
||||
self.save()
|
||||
|
||||
def reimage_nodes(self, nodes: List[Node]) -> None:
|
||||
from .tasks.main import Task
|
||||
|
||||
if not nodes:
|
||||
logging.debug("no nodes to reimage")
|
||||
return
|
||||
|
||||
for node in nodes:
|
||||
for entry in NodeTasks.get_by_machine_id(node.machine_id):
|
||||
task = Task.get_by_task_id(entry.task_id)
|
||||
if isinstance(task, Task):
|
||||
if task.state in [TaskState.stopping, TaskState.stopped]:
|
||||
continue
|
||||
|
||||
task.mark_failed(
|
||||
Error(
|
||||
code=ErrorCode.TASK_FAILED,
|
||||
errors=["node reimaged during task execution"],
|
||||
)
|
||||
)
|
||||
entry.delete()
|
||||
|
||||
if self.state == ScalesetState.shutdown:
|
||||
self.delete_nodes(nodes)
|
||||
return
|
||||
@ -766,28 +831,28 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logging.info("scaleset shutdown: %s", self.scaleset_id)
|
||||
size = get_vmss_size(self.scaleset_id)
|
||||
logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
|
||||
if size is None or size == 0:
|
||||
self.state = ScalesetState.halt
|
||||
self.halt()
|
||||
return
|
||||
self.save()
|
||||
|
||||
def halt(self) -> None:
|
||||
self.state = ScalesetState.halt
|
||||
ScalesetShrinkQueue(self.scaleset_id).delete()
|
||||
|
||||
for node in Node.search_states(scaleset_id=self.scaleset_id):
|
||||
logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
|
||||
node.delete()
|
||||
|
||||
vmss = get_vmss(self.scaleset_id)
|
||||
if vmss is None:
|
||||
logging.info("scaleset deleted: %s", self.scaleset_id)
|
||||
self.state = ScalesetState.halt
|
||||
self.delete()
|
||||
else:
|
||||
if vmss:
|
||||
logging.info("scaleset deleting: %s", self.scaleset_id)
|
||||
delete_vmss(self.scaleset_id)
|
||||
self.save()
|
||||
else:
|
||||
logging.info("scaleset deleted: %s", self.scaleset_id)
|
||||
self.state = ScalesetState.halt
|
||||
self.delete()
|
||||
|
||||
def max_size(self) -> int:
|
||||
# https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
|
||||
@ -858,3 +923,30 @@ class Scaleset(BASE_SCALESET, ORMMixin):
|
||||
@classmethod
|
||||
def key_fields(cls) -> Tuple[str, str]:
|
||||
return ("pool_name", "scaleset_id")
|
||||
|
||||
|
||||
class ShrinkEntry(BaseModel):
|
||||
shrink_id: UUID = Field(default_factory=uuid4)
|
||||
|
||||
|
||||
class ScalesetShrinkQueue:
|
||||
def __init__(self, scaleset_id: UUID):
|
||||
self.scaleset_id = scaleset_id
|
||||
|
||||
def queue_name(self) -> str:
|
||||
return "to-shrink-%s" % self.scaleset_id.hex
|
||||
|
||||
def clear(self) -> None:
|
||||
clear_queue(self.queue_name(), account_id=get_func_storage())
|
||||
|
||||
def create(self) -> None:
|
||||
create_queue(self.queue_name(), account_id=get_func_storage())
|
||||
|
||||
def delete(self) -> None:
|
||||
delete_queue(self.queue_name(), account_id=get_func_storage())
|
||||
|
||||
def add_entry(self) -> None:
|
||||
queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage())
|
||||
|
||||
def should_shrink(self) -> bool:
|
||||
return remove_first_message(self.queue_name(), account_id=get_func_storage())
|
||||
|
@ -108,7 +108,6 @@ class Task(BASE_TASK, ORMMixin):
|
||||
self.save()
|
||||
|
||||
def stopping(self) -> None:
|
||||
# TODO: we need to tell every node currently working on this task to stop
|
||||
# TODO: we need to 'unschedule' this task from the existing pools
|
||||
|
||||
self.state = TaskState.stopping
|
||||
@ -154,6 +153,11 @@ class Task(BASE_TASK, ORMMixin):
|
||||
task = tasks[0]
|
||||
return task
|
||||
|
||||
def mark_stopping(self) -> None:
|
||||
if self.state not in [TaskState.stopped, TaskState.stopping]:
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
|
||||
def mark_failed(self, error: Error) -> None:
|
||||
if self.state in [TaskState.stopped, TaskState.stopping]:
|
||||
logging.debug(
|
||||
|
Reference in New Issue
Block a user