refactor node state to fully put the agent in charge (#90)

bmc-msft
2020-10-03 02:43:04 -04:00
committed by GitHub
parent a088e72299
commit e308a4ae1e
12 changed files with 326 additions and 198 deletions

View File

@@ -88,6 +88,15 @@ def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceC
         return None


+def clear_queue(name: QueueNameType, *, account_id: str) -> None:
+    queue = get_queue(name, account_id=account_id)
+    if queue:
+        try:
+            queue.clear_messages()
+        except ResourceNotFoundError:
+            return None
+
+
 def send_message(
     name: QueueNameType,
     message: bytes,
@@ -102,6 +111,19 @@ def send_message(
             pass


+def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
+    create_queue(name, account_id=account_id)
+    queue = get_queue(name, account_id=account_id)
+    if queue:
+        try:
+            for message in queue.receive_messages():
+                queue.delete_message(message)
+                return True
+        except ResourceNotFoundError:
+            return False
+    return False
+
+
 A = TypeVar("A", bound=BaseModel)
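Note: both helpers swallow ResourceNotFoundError so callers can treat a missing queue like an empty one, and remove_first_message acts as an atomic take-one-token primitive: each call deletes at most one message and reports whether it did. A minimal sketch of a consumer built only on the function above (the queue name here is hypothetical):

    def take_tokens(count: int, *, account_id: str) -> int:
        # Consume up to `count` tokens, stopping at the first miss.
        # Each remove_first_message call pops at most one message.
        taken = 0
        for _ in range(count):
            if not remove_first_message("example-tokens", account_id=account_id):
                break
            taken += 1
        return taken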

View File

@@ -76,7 +76,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:

 def list_instance_ids(name: UUID) -> Dict[UUID, str]:
-    logging.info("get instance IDs for scaleset: %s", name)
+    logging.debug("get instance IDs for scaleset: %s", name)
     resource_group = get_base_resource_group()
     compute_client = mgmt_client_factory(ComputeManagementClient)

View File

@@ -6,7 +6,7 @@
 import datetime
 import logging
 from typing import Dict, List, Optional, Tuple, Union
-from uuid import UUID
+from uuid import UUID, uuid4

 from onefuzztypes.enums import (
     OS,
@@ -15,7 +15,6 @@ from onefuzztypes.enums import (
     NodeState,
     PoolState,
     ScalesetState,
-    TaskState,
 )
 from onefuzztypes.models import Error
 from onefuzztypes.models import Node as BASE_NODE
@@ -32,14 +31,21 @@ from onefuzztypes.models import (
     WorkUnitSummary,
 )
 from onefuzztypes.primitives import PoolName, Region
-from pydantic import Field
+from pydantic import BaseModel, Field

 from .__version__ import __version__
 from .azure.auth import build_auth
-from .azure.creds import get_fuzz_storage
+from .azure.creds import get_func_storage, get_fuzz_storage
 from .azure.image import get_os
 from .azure.network import Network
-from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
+from .azure.queue import (
+    clear_queue,
+    create_queue,
+    delete_queue,
+    peek_queue,
+    queue_object,
+    remove_first_message,
+)
 from .azure.vmss import (
     UnableToUpdate,
     create_vmss,
@@ -99,7 +105,7 @@ class Node(BASE_NODE, ORMMixin):
         # azure table query always return false when the column does not exist
         # We write the query this way to allow us to get the nodes where the
         # version is not defined as well as the nodes with a mismatched version
-        version_query = "not (version ne '%s')" % __version__
+        version_query = "not (version eq '%s')" % __version__
         return cls.search(query=query, raw_unchecked_filter=version_query)

     @classmethod
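The quirk this works around, per the comment in the diff itself: in an Azure Table filter, any comparison against a property an entity does not have evaluates to false, so version ne 'X' silently skips rows that never wrote a version column, while not (version eq 'X') matches them. A self-contained illustration of the two filters (plain Python standing in for the table service; entity values are made up):

    CURRENT = "1.1.0"
    entities = [
        {"machine_id": "a", "version": "1.0.0"},  # outdated node
        {"machine_id": "b"},                      # node that never wrote a version
        {"machine_id": "c", "version": CURRENT},  # up-to-date node
    ]

    def version_ne(e: dict) -> bool:
        # version ne 'CURRENT': false when the column is missing
        return "version" in e and e["version"] != CURRENT

    def not_version_eq(e: dict) -> bool:
        # not (version eq 'CURRENT'): true when the column is missing
        return not ("version" in e and e["version"] == CURRENT)

    assert [e["machine_id"] for e in entities if version_ne(e)] == ["a"]
    assert [e["machine_id"] for e in entities if not_version_eq(e)] == ["a", "b"]

The two forms only differ for nodes that predate the version column, which is exactly the population search_outdated must not miss.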
@@ -154,16 +160,90 @@ class Node(BASE_NODE, ORMMixin):
         for node in nodes:
             if node.state not in NodeState.ready_for_reset():
                 logging.info(
-                    "stopping task %s on machine_id:%s",
-                    task_id,
+                    "stopping machine_id:%s running task:%s",
                     node.machine_id,
+                    task_id,
                 )
-                node.state = NodeState.done
-                node.save()
+                node.stop()
+
+    def mark_tasks_stopped_early(self) -> None:
+        from .tasks.main import Task
+
+        for entry in NodeTasks.get_by_machine_id(self.machine_id):
+            task = Task.get_by_task_id(entry.task_id)
+            if isinstance(task, Task):
+                task.mark_failed(
+                    Error(
+                        code=ErrorCode.TASK_FAILED,
+                        errors=["node reimaged during task execution"],
+                    )
+                )
+            entry.delete()
+
+    def could_shrink_scaleset(self) -> bool:
+        if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink():
+            return True
+        return False
+
+    def can_process_new_work(self) -> bool:
+        if self.is_outdated():
+            logging.info(
+                "can_schedule old version machine_id:%s version:%s",
+                self.machine_id,
+                self.version,
+            )
+            self.stop()
+            return False
+
+        if self.delete_requested or self.reimage_requested:
+            logging.info(
+                "can_schedule should be recycled. machine_id:%s", self.machine_id
+            )
+            self.stop()
+            return False
+
+        if self.could_shrink_scaleset():
+            self.set_halt()
+            logging.info("node scheduled to shrink. machine_id:%s", self.machine_id)
+            return False
+
+        return True
+
+    def is_outdated(self) -> bool:
+        return self.version != __version__
+
+    def send_message(self, message: NodeCommand) -> None:
+        stop_message = NodeMessage(
+            agent_id=self.machine_id,
+            message=message,
+        )
+        stop_message.save()
+
+    def to_reimage(self, done: bool = False) -> None:
+        if done:
+            if self.state not in NodeState.ready_for_reset():
+                self.state = NodeState.done
+
+        if not self.reimage_requested and not self.delete_requested:
+            logging.info("setting reimage_requested: %s", self.machine_id)
+            self.reimage_requested = True
+        self.save()
+
+    def stop(self) -> None:
+        self.to_reimage()
+        self.send_message(NodeCommand(stop=StopNodeCommand()))
+
+    def set_shutdown(self) -> None:
+        # don't give out more work to the node, but let it finish existing work
+        logging.info("setting delete_requested: %s", self.machine_id)
+        self.delete_requested = True
+        self.save()
+
+    def set_halt(self) -> None:
+        """ Tell the node to stop everything. """
+        self.set_shutdown()
+        self.stop()
+

 class NodeTasks(BASE_NODE_TASK, ORMMixin):
     @classmethod
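Note: the new methods define three distinct teardown paths. stop() queues a stop command for the agent and flags the node for reimage; set_shutdown() only sets delete_requested, so the node finishes its current work before being removed; set_halt() combines both for immediate teardown. A sketch of how a caller might choose between them (the lookup helper and drain_first flag are assumptions, not part of this diff):

    node = Node.get_by_machine_id(machine_id)  # assumed lookup helper
    if node is not None:
        if drain_first:
            node.set_shutdown()  # no new work; in-flight task completes
        else:
            node.set_halt()      # stop now: delete_requested plus stop command

In every path the service only records intent and queues a message; the agent does the actual stopping, which is the point of this refactor.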
@@ -384,8 +464,7 @@ class Pool(BASE_POOL, ORMMixin):
             scaleset.save()

         for node in nodes:
-            node.state = NodeState.shutdown
-            node.save()
+            node.set_shutdown()

         self.save()
@@ -405,13 +484,7 @@ class Pool(BASE_POOL, ORMMixin):
             scaleset.save()

         for node in nodes:
-            logging.info(
-                "deleting node from pool: %s (%s) - machine_id:%s",
-                self.pool_id,
-                self.name,
-                node.machine_id,
-            )
-            node.delete()
+            node.set_halt()

         self.save()
@@ -496,6 +569,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
     def init(self) -> None:
         logging.info("scaleset init: %s", self.scaleset_id)

+        ScalesetShrinkQueue(self.scaleset_id).create()
+
         # Handle the race condition between a pool being deleted and a
         # scaleset being added to the pool.
         pool = Pool.get_by_name(self.pool_name)
@@ -596,49 +671,62 @@ class Scaleset(BASE_SCALESET, ORMMixin):
     # result = 'did I modify the scaleset in azure'
     def cleanup_nodes(self) -> bool:
+        if self.state == ScalesetState.halt:
+            logging.info("halting scaleset: %s", self.scaleset_id)
+            self.halt()
+            return True
+
+        to_reimage = []
+        to_delete = []
+
+        outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
+        for node in outdated:
+            logging.info(
+                "node is outdated: %s - node_version:%s api_version:%s",
+                node.machine_id,
+                node.version,
+                __version__,
+            )
+            if node.version == "1.0.0":
+                node.state = NodeState.done
+                to_reimage.append(node)
+            else:
+                node.to_reimage()
+
         nodes = Node.search_states(
             scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
         )
-        outdated = Node.search_outdated(
-            scaleset_id=self.scaleset_id,
-            states=[NodeState.free],
-        )
-        if not outdated and not nodes:
-            logging.info("no nodes need updating: %s", self.scaleset_id)
+
+        if not (nodes or outdated):
+            logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
             return False

-        to_delete = []
-        to_reimage = []
-
-        for node in outdated:
-            if node.version == "1.0.0":
-                to_reimage.append(node)
-            else:
-                stop_message = NodeMessage(
-                    agent_id=node.machine_id,
-                    message=NodeCommand(stop=StopNodeCommand()),
-                )
-                stop_message.save()
-
+        # ground truth of existing nodes
+        azure_nodes = list_instance_ids(self.scaleset_id)
+
         for node in nodes:
-            # delete nodes that are not waiting on the scaleset GC
-            if not node.scaleset_node_exists():
+            if node.machine_id not in azure_nodes:
                 logging.info(
                     "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
                 )
                 node.delete()
-            elif node.state in [NodeState.shutdown, NodeState.halt]:
+            elif node.delete_requested:
                 to_delete.append(node)
             else:
-                to_reimage.append(node)
+                if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
+                    node.set_halt()
+                    to_delete.append(node)
+                else:
+                    to_reimage.append(node)

         # Perform operations until they fail due to scaleset getting locked
         try:
             if to_delete:
                 logging.info(
                     "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
                 )
                 self.delete_nodes(to_delete)
                 for node in to_delete:
-                    node.state = NodeState.halt
-                    node.save()
+                    node.set_halt()
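Note: with the new flags, the per-node triage above reduces to four ordered outcomes. A condensed mirror of the loop (a sketch, not the committed code; shrink_queue stands in for ScalesetShrinkQueue(self.scaleset_id)):

    def triage(node, azure_nodes, shrink_queue):
        # same decision order as the cleanup_nodes loop above
        if node.machine_id not in azure_nodes:
            return "forget"   # VM already gone from Azure: drop the record
        if node.delete_requested:
            return "delete"   # node was marked for recycling
        if shrink_queue.should_shrink():
            node.set_halt()   # scaleset is oversized: this node takes a token
            return "delete"
        return "reimage"      # healthy but done: wipe and reuse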
@@ -646,69 +734,65 @@ class Scaleset(BASE_SCALESET, ORMMixin):
                 self.reimage_nodes(to_reimage)
         except UnableToUpdate:
             logging.info("scaleset update already in progress: %s", self.scaleset_id)

         return True

-    def resize(self) -> None:
-        logging.info(
-            "scaleset resize: %s - current: %s new: %s",
-            self.scaleset_id,
-            self.size,
-            self.new_size,
-        )
-
-        # no work needed to resize
-        if self.new_size is None:
-            return
+    def _resize_equal(self) -> None:
+        # NOTE: this is the only place we reset to the 'running' state.
+        # This ensures that our idea of scaleset size agrees with Azure
+        node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
+        if node_count == self.size:
+            logging.info("resize finished: %s", self.scaleset_id)
+            self.state = ScalesetState.running
+            self.save()
+            return
+        else:
+            logging.info(
+                "resize is finished, waiting for nodes to check in: "
+                "%s (%d of %d nodes checked in)",
+                self.scaleset_id,
+                node_count,
+                self.size,
+            )
+            return
+
+    def _resize_grow(self) -> None:
+        try:
+            resize_vmss(self.scaleset_id, self.size)
+        except UnableToUpdate:
+            logging.info("scaleset is mid-operation already")
+        return
+
+    def _resize_shrink(self, to_remove: int) -> None:
+        queue = ScalesetShrinkQueue(self.scaleset_id)
+        for _ in range(to_remove):
+            queue.add_entry()
+
+    def resize(self) -> None:
+        # no longer needing to resize
+        if self.state != ScalesetState.resize:
+            return
+
+        logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)
+
+        # reset the node delete queue
+        ScalesetShrinkQueue(self.scaleset_id).clear()

         # just in case, always ensure size is within max capacity
-        self.new_size = min(self.new_size, self.max_size())
+        self.size = min(self.size, self.max_size())

         # Treat Azure knowledge of the size of the scaleset as "ground truth"
         size = get_vmss_size(self.scaleset_id)
         if size is None:
-            logging.info("scaleset is unavailable. Re-queuing")
-            self.save()
+            logging.info("scaleset is unavailable: %s", self.scaleset_id)
             return

-        if size == self.new_size:
-            # NOTE: this is the only place we reset to the 'running' state.
-            # This ensures that our idea of scaleset size agrees with Azure
-            node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
-            if node_count == self.size:
-                logging.info("resize finished: %s", self.scaleset_id)
-                self.new_size = None
-                self.state = ScalesetState.running
-            else:
-                logging.info(
-                    "resize is finished, waiting for nodes to check in: "
-                    "%s (%d of %d nodes checked in)",
-                    self.scaleset_id,
-                    node_count,
-                    self.size,
-                )
-        # When adding capacity, call the resize API directly
-        elif self.new_size > self.size:
-            try:
-                resize_vmss(self.scaleset_id, self.new_size)
-            except UnableToUpdate:
-                logging.info("scaleset is mid-operation already")
-        # Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_node
-        # to pick up that the scaleset is too big upon task completion
+        if size == self.size:
+            self._resize_equal()
+        elif self.size > size:
+            self._resize_grow()
         else:
-            nodes = Node.search_states(
-                scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
-            )
-            for node in nodes:
-                if size > self.new_size:
-                    node.state = NodeState.halt
-                    node.save()
-                    size -= 1
-                else:
-                    break
-        self.save()
+            self._resize_shrink(size - self.size)

     def delete_nodes(self, nodes: List[Node]) -> None:
         if not nodes:
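Note: resize now treats Azure's instance count as ground truth and self.size as the single desired value (new_size is gone), so the body reduces to a three-way comparison. A compressed sketch of the new flow, with `actual` standing in for the local variable `size`:

    actual = get_vmss_size(self.scaleset_id)  # None while the scaleset is locked
    if actual is None:
        return                       # unavailable: retry on the next cycle
    if actual == self.size:
        self._resize_equal()         # flip to 'running' once nodes check in
    elif self.size > actual:
        self._resize_grow()          # ask Azure for more instances
    else:
        self._resize_shrink(actual - self.size)  # one shrink token per extra node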
@@ -723,31 +807,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):
         logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
         delete_vmss_nodes(self.scaleset_id, machine_ids)
-        self.size -= len(machine_ids)
-        self.save()

     def reimage_nodes(self, nodes: List[Node]) -> None:
-        from .tasks.main import Task
-
         if not nodes:
             logging.debug("no nodes to reimage")
             return

-        for node in nodes:
-            for entry in NodeTasks.get_by_machine_id(node.machine_id):
-                task = Task.get_by_task_id(entry.task_id)
-                if isinstance(task, Task):
-                    if task.state in [TaskState.stopping, TaskState.stopped]:
-                        continue
-                    task.mark_failed(
-                        Error(
-                            code=ErrorCode.TASK_FAILED,
-                            errors=["node reimaged during task execution"],
-                        )
-                    )
-                entry.delete()
-
         if self.state == ScalesetState.shutdown:
             self.delete_nodes(nodes)
             return
@@ -766,28 +831,28 @@ class Scaleset(BASE_SCALESET, ORMMixin):
         )

     def shutdown(self) -> None:
-        logging.info("scaleset shutdown: %s", self.scaleset_id)
+        size = get_vmss_size(self.scaleset_id)
+        logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
+        if size is None or size == 0:
+            self.state = ScalesetState.halt
+            self.halt()
+            return
         self.save()

     def halt(self) -> None:
-        self.state = ScalesetState.halt
+        ScalesetShrinkQueue(self.scaleset_id).delete()

         for node in Node.search_states(scaleset_id=self.scaleset_id):
             logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
             node.delete()

         vmss = get_vmss(self.scaleset_id)
-        if vmss is None:
-            logging.info("scaleset deleted: %s", self.scaleset_id)
-            self.state = ScalesetState.halt
-            self.delete()
-        else:
+        if vmss:
             logging.info("scaleset deleting: %s", self.scaleset_id)
             delete_vmss(self.scaleset_id)
             self.save()
+        else:
+            logging.info("scaleset deleted: %s", self.scaleset_id)
+            self.state = ScalesetState.halt
+            self.delete()

     def max_size(self) -> int:
         # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
@@ -858,3 +923,30 @@ class Scaleset(BASE_SCALESET, ORMMixin):
     @classmethod
     def key_fields(cls) -> Tuple[str, str]:
         return ("pool_name", "scaleset_id")
+
+
+class ShrinkEntry(BaseModel):
+    shrink_id: UUID = Field(default_factory=uuid4)
+
+
+class ScalesetShrinkQueue:
+    def __init__(self, scaleset_id: UUID):
+        self.scaleset_id = scaleset_id
+
+    def queue_name(self) -> str:
+        return "to-shrink-%s" % self.scaleset_id.hex
+
+    def clear(self) -> None:
+        clear_queue(self.queue_name(), account_id=get_func_storage())
+
+    def create(self) -> None:
+        create_queue(self.queue_name(), account_id=get_func_storage())
+
+    def delete(self) -> None:
+        delete_queue(self.queue_name(), account_id=get_func_storage())
+
+    def add_entry(self) -> None:
+        queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage())
+
+    def should_shrink(self) -> bool:
+        return remove_first_message(self.queue_name(), account_id=get_func_storage())
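Note: ScalesetShrinkQueue turns "remove N nodes" into N single-use tokens, which is what lets nodes be retired individually as they go idle. A sketch of the full protocol, assuming one scaleset and hypothetical driver code around the class above:

    queue = ScalesetShrinkQueue(scaleset_id)
    queue.create()

    # resize path: shrink by two nodes
    for _ in range(2):
        queue.add_entry()          # one ShrinkEntry per node to remove

    # cleanup path: each idle node is offered a token exactly once
    for node in idle_nodes:
        if queue.should_shrink():  # pops at most one message per call
            node.set_halt()        # token consumed: this node goes away
        else:
            node.to_reimage()      # no token left: node is recycled for reuse

Because remove_first_message deletes at most one message per call, no more than two nodes are halted here regardless of how many race on the queue.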

View File

@@ -108,7 +108,6 @@ class Task(BASE_TASK, ORMMixin):
         self.save()

     def stopping(self) -> None:
-        # TODO: we need to tell every node currently working on this task to stop
         # TODO: we need to 'unschedule' this task from the existing pools
         self.state = TaskState.stopping
@@ -154,6 +153,11 @@ class Task(BASE_TASK, ORMMixin):
         task = tasks[0]
         return task

+    def mark_stopping(self) -> None:
+        if self.state not in [TaskState.stopped, TaskState.stopping]:
+            self.state = TaskState.stopping
+            self.save()
+
     def mark_failed(self, error: Error) -> None:
         if self.state in [TaskState.stopped, TaskState.stopping]:
             logging.debug(