refactor node state to fully put the agent in charge (#90)
@@ -3,16 +3,14 @@

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging

import azure.functions as func
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error, NodeCommand, StopNodeCommand
from onefuzztypes.models import Error
from onefuzztypes.requests import CanScheduleRequest
from onefuzztypes.responses import CanSchedule

from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.pools import Node, NodeMessage
from ..onefuzzlib.pools import Node
from ..onefuzzlib.request import not_ok, ok, parse_request
from ..onefuzzlib.tasks.main import Task
@@ -31,23 +29,13 @@ def post(req: func.HttpRequest) -> func.HttpResponse:

    allowed = True
    work_stopped = False

    if node.is_outdated():
        logging.info(
            "received can_schedule request from outdated node '%s' version '%s'",
            node.machine_id,
            node.version,
        )

    if not node.can_process_new_work():
        allowed = False
        stop_message = NodeMessage(
            agent_id=node.machine_id,
            message=NodeCommand(stop=StopNodeCommand()),
        )
        stop_message.save()

    task = Task.get_by_task_id(request.task_id)

    work_stopped = isinstance(task, Error) or (task.state != TaskState.scheduled)
    work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down()
    if work_stopped:
        allowed = False
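For orientation: after this hunk, the per-request version and recycle checks in can_schedule collapse into a single call to Node.can_process_new_work(). A standalone toy model of that decision (hypothetical names, not code from the commit) behaves like this:

# Toy model of the scheduling gate; FakeNode and its fields are hypothetical.
from dataclasses import dataclass

@dataclass
class FakeNode:
    outdated: bool = False
    delete_requested: bool = False
    reimage_requested: bool = False
    shrink_pending: bool = False

    def can_process_new_work(self) -> bool:
        # outdated, recycle-flagged, or shrink-marked nodes never receive new work
        return not (
            self.outdated
            or self.delete_requested
            or self.reimage_requested
            or self.shrink_pending
        )

assert FakeNode().can_process_new_work()
assert not FakeNode(delete_requested=True).can_process_new_work()
assert not FakeNode(outdated=True).can_process_new_work()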
@@ -47,11 +47,28 @@ def get_node_checked(machine_id: UUID) -> Node:

def on_state_update(
    machine_id: UUID,
    state_update: NodeStateUpdate,
) -> func.HttpResponse:
) -> None:
    state = state_update.state
    node = get_node_checked(machine_id)

    if state == NodeState.init or node.state not in NodeState.ready_for_reset():
    if state == NodeState.free:
        if node.reimage_requested or node.delete_requested:
            logging.info("stopping free node with reset flags: %s", node.machine_id)
            node.stop()
            return

        if node.could_shrink_scaleset():
            logging.info("stopping free node to resize scaleset: %s", node.machine_id)
            node.set_halt()
            return

    if state == NodeState.init:
        if node.delete_requested:
            node.stop()
            return
        node.reimage_requested = False
        node.save()
    elif node.state not in NodeState.ready_for_reset():
        if node.state != state:
            node.state = state
            node.save()
@@ -91,26 +108,28 @@ def on_state_update(
            )
            node_task.save()
    elif state == NodeState.done:
        # if tasks are running on the node when it reports as Done
        # those are stopped early
        node.mark_tasks_stopped_early()

        # Model-validated.
        #
        # This field will be required in the future.
        # For now, it is optional for back compat.
        done_data = cast(Optional[NodeDoneEventData], state_update.data)

        if done_data:
            # TODO: do something with this done data
            if done_data.error:
                logging.error(
                    "node `done` with error: machine_id = %s, data = %s",
                    "node 'done' with error: machine_id:%s, data:%s",
                    machine_id,
                    done_data,
                )
    else:
        logging.info("ignoring state updates from the node: %s: %s", machine_id, state)

    return ok(BoolResult(result=True))


def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None:
    if event.running:
        task_id = event.running.task_id
    elif event.done:
@@ -129,37 +148,32 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
            task.state = TaskState.running
        if node.state not in NodeState.ready_for_reset():
            node.state = NodeState.busy
            node.save()
        node_task.save()

        # Start the clock for the task if it wasn't started already
        # (as happens in 1.0.0 agents)
        task.on_start()
    elif event.done:
        # Only record exit status if the task isn't already shutting down.
        #
        # It's ok for the agent to fail because resources vanish out from underneath
        # it during deletion.
        if task.state not in TaskState.shutting_down():
            exit_status = event.done.exit_status

            if not exit_status.success:
                logging.error("task failed: status = %s", exit_status)

                task.mark_failed(
                    Error(
                        code=ErrorCode.TASK_FAILED,
                        errors=[
                            "task failed. exit_status = %s" % exit_status,
                            event.done.stdout,
                            event.done.stderr,
                        ],
                    )
                )

            task.state = TaskState.stopping
        if node.state not in NodeState.ready_for_reset():
            node.state = NodeState.done
        node_task.delete()

            exit_status = event.done.exit_status
            if not exit_status.success:
                logging.error("task failed. status:%s", exit_status)
                task.mark_failed(
                    Error(
                        code=ErrorCode.TASK_FAILED,
                        errors=[
                            "task failed. exit_status:%s" % exit_status,
                            event.done.stdout,
                            event.done.stderr,
                        ],
                    )
                )
            else:
                task.mark_stopping()

        node.to_reimage(done=True)
    else:
        err = Error(
            code=ErrorCode.INVALID_REQUEST,
@@ -168,11 +182,9 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
        raise RequestException(err)

    task.save()
    node.save()

    task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event)
    task_event.save()
    return ok(BoolResult(result=True))


def post(req: func.HttpRequest) -> func.HttpResponse:
@@ -197,9 +209,11 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
        return not_ok(err, context=ERROR_CONTEXT)

    if event.state_update:
        return on_state_update(envelope.machine_id, event.state_update)
        on_state_update(envelope.machine_id, event.state_update)
        return ok(BoolResult(result=True))
    elif event.worker_event:
        return on_worker_event(envelope.machine_id, event.worker_event)
        on_worker_event(envelope.machine_id, event.worker_event)
        return ok(BoolResult(result=True))
    else:
        err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
        return not_ok(err, context=ERROR_CONTEXT)
@@ -6,7 +6,7 @@

from uuid import UUID

import azure.functions as func
from onefuzztypes.enums import ErrorCode
from onefuzztypes.enums import ErrorCode, NodeState
from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration
@@ -78,7 +78,6 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
    registration_request = parse_uri(AgentRegistrationPost, req)
    if isinstance(registration_request, Error):
        return not_ok(registration_request, context="agent registration")
    agent_node = Node.get_by_machine_id(registration_request.machine_id)

    pool = Pool.get_by_name(registration_request.pool_name)
    if isinstance(pool, Error):
@@ -90,20 +89,23 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
            context="agent registration",
        )

    if agent_node is None:
        agent_node = Node(
    node = Node.get_by_machine_id(registration_request.machine_id)
    if node:
        if node.version != registration_request.version:
            NodeMessage.clear_messages(node.machine_id)
        node.version = registration_request.version
        node.reimage_requested = False
        node.state = NodeState.init
    else:
        node = Node(
            pool_name=registration_request.pool_name,
            machine_id=registration_request.machine_id,
            scaleset_id=registration_request.scaleset_id,
            version=registration_request.version,
        )
        agent_node.save()
    elif agent_node.version.lower != registration_request.version:
        NodeMessage.clear_messages(agent_node.machine_id)
        agent_node.version = registration_request.version
        agent_node.save()
    node.save()

    return create_registration_response(agent_node.machine_id, pool)
    return create_registration_response(node.machine_id, pool)


def main(req: func.HttpRequest) -> func.HttpResponse:
@@ -4,7 +4,7 @@
# Licensed under the MIT License.

import azure.functions as func
from onefuzztypes.enums import ErrorCode, NodeState
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.requests import NodeGet, NodeSearch
from onefuzztypes.responses import BoolResult
@@ -54,8 +54,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
            context=request.machine_id,
        )

    node.state = NodeState.halt
    node.save()
    node.set_halt()

    return ok(BoolResult(result=True))
@@ -72,9 +71,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse:
            context=request.machine_id,
        )

    node.state = NodeState.done
    node.save()

    node.stop()
    return ok(BoolResult(result=True))
@@ -88,6 +88,15 @@ def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceC
        return None


def clear_queue(name: QueueNameType, *, account_id: str) -> None:
    queue = get_queue(name, account_id=account_id)
    if queue:
        try:
            queue.clear_messages()
        except ResourceNotFoundError:
            return None


def send_message(
    name: QueueNameType,
    message: bytes,
@@ -102,6 +111,19 @@ def send_message(
        pass


def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
    create_queue(name, account_id=account_id)
    queue = get_queue(name, account_id=account_id)
    if queue:
        try:
            for message in queue.receive_messages():
                queue.delete_message(message)
                return True
        except ResourceNotFoundError:
            return False
    return False


A = TypeVar("A", bound=BaseModel)
@@ -76,7 +76,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:


def list_instance_ids(name: UUID) -> Dict[UUID, str]:
    logging.info("get instance IDs for scaleset: %s", name)
    logging.debug("get instance IDs for scaleset: %s", name)
    resource_group = get_base_resource_group()
    compute_client = mgmt_client_factory(ComputeManagementClient)
@@ -6,7 +6,7 @@

import datetime
import logging
from typing import Dict, List, Optional, Tuple, Union
from uuid import UUID
from uuid import UUID, uuid4

from onefuzztypes.enums import (
    OS,
@@ -15,7 +15,6 @@ from onefuzztypes.enums import (
    NodeState,
    PoolState,
    ScalesetState,
    TaskState,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Node as BASE_NODE
@@ -32,14 +31,21 @@ from onefuzztypes.models import (
    WorkUnitSummary,
)
from onefuzztypes.primitives import PoolName, Region
from pydantic import Field
from pydantic import BaseModel, Field

from .__version__ import __version__
from .azure.auth import build_auth
from .azure.creds import get_fuzz_storage
from .azure.creds import get_func_storage, get_fuzz_storage
from .azure.image import get_os
from .azure.network import Network
from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
from .azure.queue import (
    clear_queue,
    create_queue,
    delete_queue,
    peek_queue,
    queue_object,
    remove_first_message,
)
from .azure.vmss import (
    UnableToUpdate,
    create_vmss,
@@ -99,7 +105,7 @@ class Node(BASE_NODE, ORMMixin):
        # azure table query always return false when the column does not exist
        # We write the query this way to allow us to get the nodes where the
        # version is not defined as well as the nodes with a mismatched version
        version_query = "not (version ne '%s')" % __version__
        version_query = "not (version eq '%s')" % __version__
        return cls.search(query=query, raw_unchecked_filter=version_query)

    @classmethod
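The comment above relies on a quirk of Azure Table queries: comparing against a column that was never written evaluates to false, so negating an equality check also matches rows that have no version column at all. A quick, self-contained illustration of the filter string this builds (the version value is made up):

# Illustration only: the raw filter string built by the query above.
current_version = "1.1.0"  # stand-in for __version__
version_query = "not (version eq '%s')" % current_version
# matches nodes whose version differs from the current one, and nodes that
# never recorded a version (their 'version eq ...' comparison is false)
assert version_query == "not (version eq '1.1.0')"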
@@ -154,16 +160,90 @@ class Node(BASE_NODE, ORMMixin):
        for node in nodes:
            if node.state not in NodeState.ready_for_reset():
                logging.info(
                    "stopping task %s on machine_id:%s",
                    task_id,
                    "stopping machine_id:%s running task:%s",
                    node.machine_id,
                    task_id,
                )
                node.state = NodeState.done
                node.save()
                node.stop()

    def mark_tasks_stopped_early(self) -> None:
        from .tasks.main import Task

        for entry in NodeTasks.get_by_machine_id(self.machine_id):
            task = Task.get_by_task_id(entry.task_id)
            if isinstance(task, Task):
                task.mark_failed(
                    Error(
                        code=ErrorCode.TASK_FAILED,
                        errors=["node reimaged during task execution"],
                    )
                )
            entry.delete()

    def could_shrink_scaleset(self) -> bool:
        if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink():
            return True
        return False

    def can_process_new_work(self) -> bool:
        if self.is_outdated():
            logging.info(
                "can_schedule old version machine_id:%s version:%s",
                self.machine_id,
                self.version,
            )
            self.stop()
            return False

        if self.delete_requested or self.reimage_requested:
            logging.info(
                "can_schedule should be recycled. machine_id:%s", self.machine_id
            )
            self.stop()
            return False

        if self.could_shrink_scaleset():
            self.set_halt()
            logging.info("node scheduled to shrink. machine_id:%s", self.machine_id)
            return False

        return True

    def is_outdated(self) -> bool:
        return self.version != __version__

    def send_message(self, message: NodeCommand) -> None:
        stop_message = NodeMessage(
            agent_id=self.machine_id,
            message=message,
        )
        stop_message.save()

    def to_reimage(self, done: bool = False) -> None:
        if done:
            if self.state not in NodeState.ready_for_reset():
                self.state = NodeState.done

        if not self.reimage_requested and not self.delete_requested:
            logging.info("setting reimage_requested: %s", self.machine_id)
            self.reimage_requested = True
        self.save()

    def stop(self) -> None:
        self.to_reimage()
        self.send_message(NodeCommand(stop=StopNodeCommand()))

    def set_shutdown(self) -> None:
        # don't give out more work to the node, but let it finish existing work
        logging.info("setting delete_requested: %s", self.machine_id)
        self.delete_requested = True
        self.save()

    def set_halt(self) -> None:
        """ Tell the node to stop everything. """
        self.set_shutdown()
        self.stop()


class NodeTasks(BASE_NODE_TASK, ORMMixin):
    @classmethod
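Taken together, the new Node methods above give the service three teardown paths that the agent, rather than the service, carries out. A rough summary (my reading of the diff, not text from the commit):

# set_shutdown(): sets delete_requested and saves; the node finishes its
#                 current work but is never handed new work.
# stop():         to_reimage() marks reimage_requested (unless a delete is
#                 already pending) and send_message() queues a
#                 NodeCommand(stop=StopNodeCommand()) for the agent to act on.
# set_halt():     set_shutdown() followed by stop() -- stop everything and
#                 mark the node for removal.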
@@ -384,8 +464,7 @@ class Pool(BASE_POOL, ORMMixin):
            scaleset.save()

        for node in nodes:
            node.state = NodeState.shutdown
            node.save()
            node.set_shutdown()

        self.save()
@@ -405,13 +484,7 @@ class Pool(BASE_POOL, ORMMixin):
            scaleset.save()

        for node in nodes:
            logging.info(
                "deleting node from pool: %s (%s) - machine_id:%s",
                self.pool_id,
                self.name,
                node.machine_id,
            )
            node.delete()
            node.set_halt()

        self.save()
@@ -496,6 +569,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
    def init(self) -> None:
        logging.info("scaleset init: %s", self.scaleset_id)

        ScalesetShrinkQueue(self.scaleset_id).create()

        # Handle the race condition between a pool being deleted and a
        # scaleset being added to the pool.
        pool = Pool.get_by_name(self.pool_name)
@@ -596,49 +671,62 @@ class Scaleset(BASE_SCALESET, ORMMixin):
    # result = 'did I modify the scaleset in azure'
    def cleanup_nodes(self) -> bool:
        if self.state == ScalesetState.halt:
            logging.info("halting scaleset: %s", self.scaleset_id)
            self.halt()
            return True

        to_reimage = []
        to_delete = []

        outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
        for node in outdated:
            logging.info(
                "node is outdated: %s - node_version:%s api_version:%s",
                node.machine_id,
                node.version,
                __version__,
            )
            if node.version == "1.0.0":
                node.state = NodeState.done
                to_reimage.append(node)
            else:
                node.to_reimage()

        nodes = Node.search_states(
            scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
        )

        outdated = Node.search_outdated(
            scaleset_id=self.scaleset_id,
            states=[NodeState.free],
        )

        if not (nodes or outdated):
            logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
        if not outdated and not nodes:
            logging.info("no nodes need updating: %s", self.scaleset_id)
            return False

        to_delete = []
        to_reimage = []

        for node in outdated:
            if node.version == "1.0.0":
                to_reimage.append(node)
            else:
                stop_message = NodeMessage(
                    agent_id=node.machine_id,
                    message=NodeCommand(stop=StopNodeCommand()),
                )
                stop_message.save()

        # ground truth of existing nodes
        azure_nodes = list_instance_ids(self.scaleset_id)

        for node in nodes:
            # delete nodes that are not waiting on the scaleset GC
            if not node.scaleset_node_exists():
            if node.machine_id not in azure_nodes:
                logging.info(
                    "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
                )
                node.delete()
            elif node.state in [NodeState.shutdown, NodeState.halt]:
            elif node.delete_requested:
                to_delete.append(node)
            else:
                to_reimage.append(node)
                if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
                    node.set_halt()
                    to_delete.append(node)
                else:
                    to_reimage.append(node)

        # Perform operations until they fail due to scaleset getting locked
        try:
            if to_delete:
                logging.info(
                    "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
                )
                self.delete_nodes(to_delete)
                for node in to_delete:
                    node.set_halt()
                    node.state = NodeState.halt
                    node.save()
@@ -646,69 +734,65 @@ class Scaleset(BASE_SCALESET, ORMMixin):
                self.reimage_nodes(to_reimage)
        except UnableToUpdate:
            logging.info("scaleset update already in progress: %s", self.scaleset_id)

        return True

    def resize(self) -> None:
        logging.info(
            "scaleset resize: %s - current: %s new: %s",
            self.scaleset_id,
            self.size,
            self.new_size,
        )

        # no work needed to resize
        if self.new_size is None:
    def _resize_equal(self) -> None:
        # NOTE: this is the only place we reset to the 'running' state.
        # This ensures that our idea of scaleset size agrees with Azure
        node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
        if node_count == self.size:
            logging.info("resize finished: %s", self.scaleset_id)
            self.state = ScalesetState.running
            self.save()
            return
        else:
            logging.info(
                "resize is finished, waiting for nodes to check in: "
                "%s (%d of %d nodes checked in)",
                self.scaleset_id,
                node_count,
                self.size,
            )
            return

    def _resize_grow(self) -> None:
        try:
            resize_vmss(self.scaleset_id, self.size)
        except UnableToUpdate:
            logging.info("scaleset is mid-operation already")
            return

    def _resize_shrink(self, to_remove: int) -> None:
        queue = ScalesetShrinkQueue(self.scaleset_id)
        for _ in range(to_remove):
            queue.add_entry()

    def resize(self) -> None:
        # no longer needing to resize
        if self.state != ScalesetState.resize:
            return

        logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)

        # reset the node delete queue
        ScalesetShrinkQueue(self.scaleset_id).clear()

        # just in case, always ensure size is within max capacity
        self.new_size = min(self.new_size, self.max_size())
        self.size = min(self.size, self.max_size())

        # Treat Azure knowledge of the size of the scaleset as "ground truth"
        size = get_vmss_size(self.scaleset_id)
        if size is None:
            logging.info("scaleset is unavailable. Re-queuing")
            self.save()
            logging.info("scaleset is unavailable: %s", self.scaleset_id)
            return

        if size == self.new_size:
            # NOTE: this is the only place we reset to the 'running' state.
            # This ensures that our idea of scaleset size agrees with Azure
            node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
            if node_count == self.size:
                logging.info("resize finished: %s", self.scaleset_id)
                self.new_size = None
                self.state = ScalesetState.running
            else:
                logging.info(
                    "resize is finished, waiting for nodes to check in: "
                    "%s (%d of %d nodes checked in)",
                    self.scaleset_id,
                    node_count,
                    self.size,
                )
        # When adding capacity, call the resize API directly
        elif self.new_size > self.size:
            try:
                resize_vmss(self.scaleset_id, self.new_size)
            except UnableToUpdate:
                logging.info("scaleset is mid-operation already")
        # Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_node
        # to pick up that the scaleset is too big upon task completion
        if size == self.size:
            self._resize_equal()
        elif self.size > size:
            self._resize_grow()
        else:
            nodes = Node.search_states(
                scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
            )
            for node in nodes:
                if size > self.new_size:
                    node.state = NodeState.halt
                    node.save()
                    size -= 1
                else:
                    break

            self.save()
            self._resize_shrink(size - self.size)

    def delete_nodes(self, nodes: List[Node]) -> None:
        if not nodes:
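The resize refactor above treats the instance count Azure reports as ground truth and dispatches on how it compares with the desired self.size. A simplified, runnable sketch of that dispatch (plan_resize is hypothetical; the real code calls _resize_equal, _resize_grow, and _resize_shrink):

# Simplified model of the new resize() dispatch; not the actual class.
def plan_resize(azure_size: int, desired_size: int) -> str:
    if azure_size == desired_size:
        return "equal: wait for all nodes to check in, then mark running"
    if desired_size > azure_size:
        return "grow: resize the VMSS up to %d" % desired_size
    # shrink: queue one ShrinkEntry per surplus node; free nodes consume
    # an entry via should_shrink() and halt themselves
    return "shrink: enqueue %d shrink entries" % (azure_size - desired_size)

assert plan_resize(10, 10).startswith("equal")
assert plan_resize(8, 10).startswith("grow")
assert plan_resize(12, 10) == "shrink: enqueue 2 shrink entries"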
@@ -723,31 +807,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):

        logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
        delete_vmss_nodes(self.scaleset_id, machine_ids)
        self.size -= len(machine_ids)
        self.save()

    def reimage_nodes(self, nodes: List[Node]) -> None:
        from .tasks.main import Task

        if not nodes:
            logging.debug("no nodes to reimage")
            return

        for node in nodes:
            for entry in NodeTasks.get_by_machine_id(node.machine_id):
                task = Task.get_by_task_id(entry.task_id)
                if isinstance(task, Task):
                    if task.state in [TaskState.stopping, TaskState.stopped]:
                        continue

                    task.mark_failed(
                        Error(
                            code=ErrorCode.TASK_FAILED,
                            errors=["node reimaged during task execution"],
                        )
                    )
                entry.delete()

        if self.state == ScalesetState.shutdown:
            self.delete_nodes(nodes)
            return
@@ -766,28 +831,28 @@ class Scaleset(BASE_SCALESET, ORMMixin):
        )

    def shutdown(self) -> None:
        logging.info("scaleset shutdown: %s", self.scaleset_id)
        size = get_vmss_size(self.scaleset_id)
        logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
        if size is None or size == 0:
            self.state = ScalesetState.halt
            self.halt()
            return
        self.save()

    def halt(self) -> None:
        self.state = ScalesetState.halt
        ScalesetShrinkQueue(self.scaleset_id).delete()

        for node in Node.search_states(scaleset_id=self.scaleset_id):
            logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
            node.delete()

        vmss = get_vmss(self.scaleset_id)
        if vmss is None:
            logging.info("scaleset deleted: %s", self.scaleset_id)
            self.state = ScalesetState.halt
            self.delete()
        else:
        if vmss:
            logging.info("scaleset deleting: %s", self.scaleset_id)
            delete_vmss(self.scaleset_id)
            self.save()
        else:
            logging.info("scaleset deleted: %s", self.scaleset_id)
            self.state = ScalesetState.halt
            self.delete()

    def max_size(self) -> int:
        # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
@@ -858,3 +923,30 @@ class Scaleset(BASE_SCALESET, ORMMixin):
    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        return ("pool_name", "scaleset_id")


class ShrinkEntry(BaseModel):
    shrink_id: UUID = Field(default_factory=uuid4)


class ScalesetShrinkQueue:
    def __init__(self, scaleset_id: UUID):
        self.scaleset_id = scaleset_id

    def queue_name(self) -> str:
        return "to-shrink-%s" % self.scaleset_id.hex

    def clear(self) -> None:
        clear_queue(self.queue_name(), account_id=get_func_storage())

    def create(self) -> None:
        create_queue(self.queue_name(), account_id=get_func_storage())

    def delete(self) -> None:
        delete_queue(self.queue_name(), account_id=get_func_storage())

    def add_entry(self) -> None:
        queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage())

    def should_shrink(self) -> bool:
        return remove_first_message(self.queue_name(), account_id=get_func_storage())
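ScalesetShrinkQueue turns shrinking into a handshake: _resize_shrink adds one entry per surplus node, and each free node that asks should_shrink() consumes at most one entry (via remove_first_message) and halts itself, so a removal is claimed exactly once. A hypothetical in-memory stand-in shows the shape of that handshake:

# In-memory stand-in for the shrink queue; the real class above is backed by
# an Azure Storage queue through queue_object()/remove_first_message().
from collections import deque
from uuid import uuid4

class FakeShrinkQueue:
    def __init__(self) -> None:
        self.entries: deque = deque()

    def add_entry(self) -> None:
        self.entries.append(uuid4())

    def should_shrink(self) -> bool:
        # each call claims at most one pending shrink entry
        if self.entries:
            self.entries.popleft()
            return True
        return False

queue = FakeShrinkQueue()
queue.add_entry()
queue.add_entry()
assert queue.should_shrink()      # first free node halts itself
assert queue.should_shrink()      # second free node halts itself
assert not queue.should_shrink()  # remaining nodes keep working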
@@ -108,7 +108,6 @@ class Task(BASE_TASK, ORMMixin):
        self.save()

    def stopping(self) -> None:
        # TODO: we need to tell every node currently working on this task to stop
        # TODO: we need to 'unschedule' this task from the existing pools

        self.state = TaskState.stopping
@@ -154,6 +153,11 @@ class Task(BASE_TASK, ORMMixin):
        task = tasks[0]
        return task

    def mark_stopping(self) -> None:
        if self.state not in [TaskState.stopped, TaskState.stopping]:
            self.state = TaskState.stopping
            self.save()

    def mark_failed(self, error: Error) -> None:
        if self.state in [TaskState.stopped, TaskState.stopping]:
            logging.debug(
@@ -31,5 +31,5 @@ pydantic~=1.6.1
PyJWT~=1.7.1
requests~=2.24.0
memoization~=0.3.1
# onefuzztypes version is set during build
# onefuzz types version is set during build
onefuzztypes==0.0.0
@@ -109,6 +109,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
        scaleset.state = ScalesetState.halt
    else:
        scaleset.state = ScalesetState.shutdown

    scaleset.save()
    scaleset.auth = None
    return ok(scaleset)
@@ -133,7 +134,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse:
        )

    if request.size is not None:
        scaleset.new_size = request.size
        scaleset.size = request.size
        scaleset.state = ScalesetState.resize

    scaleset.save()
@@ -13,21 +13,28 @@ from ..onefuzzlib.pools import Scaleset


def process_scaleset(scaleset: Scaleset) -> None:
    if scaleset.state == ScalesetState.halt:
        scaleset.halt()
        return
    logging.debug("checking scaleset for updates: %s", scaleset.scaleset_id)

    if scaleset.state == ScalesetState.resize:
        scaleset.resize()

    # if the scaleset is touched during cleanup, don't continue to process it
    if scaleset.cleanup_nodes():
        logging.debug("scaleset needed cleanup: %s", scaleset.scaleset_id)
        return

    if scaleset.state in ScalesetState.needs_work():
    if (
        scaleset.state in ScalesetState.needs_work()
        and scaleset.state != ScalesetState.resize
    ):
        logging.info(
            "exec scaleset state: %s - %s",
            scaleset.scaleset_id,
            scaleset.state.name,
            scaleset.state,
        )
        getattr(scaleset, scaleset.state.name)()

        if hasattr(scaleset, scaleset.state.name):
            getattr(scaleset, scaleset.state.name)()
        return
@@ -399,6 +399,8 @@ class Node(BaseModel):
    scaleset_id: Optional[UUID] = None
    tasks: Optional[List[Tuple[UUID, NodeTaskState]]] = None
    version: str = Field(default="1.0.0")
    reimage_requested: bool = Field(default=False)
    delete_requested: bool = Field(default=False)


class ScalesetSummary(BaseModel):
@@ -447,7 +449,6 @@ class Scaleset(BaseModel):
    image: str
    region: Region
    size: int
    new_size: Optional[int]
    spot_instances: bool
    error: Optional[Error]
    nodes: Optional[List[ScalesetNodeState]]