refactor node state to fully put the agent in charge (#90)

Author: bmc-msft
Date: 2020-10-03 02:43:04 -04:00
Committed by: GitHub
Parent: a088e72299
Commit: e308a4ae1e

12 changed files with 326 additions and 198 deletions

View File

@@ -3,16 +3,14 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-import logging
 import azure.functions as func
 from onefuzztypes.enums import ErrorCode, TaskState
-from onefuzztypes.models import Error, NodeCommand, StopNodeCommand
+from onefuzztypes.models import Error
 from onefuzztypes.requests import CanScheduleRequest
 from onefuzztypes.responses import CanSchedule
 from ..onefuzzlib.agent_authorization import verify_token
-from ..onefuzzlib.pools import Node, NodeMessage
+from ..onefuzzlib.pools import Node
 from ..onefuzzlib.request import not_ok, ok, parse_request
 from ..onefuzzlib.tasks.main import Task

@@ -31,23 +29,13 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
     allowed = True
     work_stopped = False

-    if node.is_outdated():
-        logging.info(
-            "received can_schedule request from outdated node '%s' version '%s'",
-            node.machine_id,
-            node.version,
-        )
+    if not node.can_process_new_work():
         allowed = False
-        stop_message = NodeMessage(
-            agent_id=node.machine_id,
-            message=NodeCommand(stop=StopNodeCommand()),
-        )
-        stop_message.save()

     task = Task.get_by_task_id(request.task_id)
-    work_stopped = isinstance(task, Error) or (task.state != TaskState.scheduled)
+    work_stopped = isinstance(task, Error) or task.state in TaskState.shutting_down()
     if work_stopped:
         allowed = False
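The schedulability check is now a single question put to the node. A minimal, self-contained sketch (plain Python with illustrative names, not the onefuzz implementation) of the conditions that can_process_new_work() folds together:

from dataclasses import dataclass

@dataclass
class NodeFlags:
    outdated: bool = False
    delete_requested: bool = False
    reimage_requested: bool = False
    scaleset_wants_to_shrink: bool = False

def can_process_new_work(flags: NodeFlags) -> bool:
    # any of these means: recycle the node instead of handing it new work
    if flags.outdated:
        return False
    if flags.delete_requested or flags.reimage_requested:
        return False
    if flags.scaleset_wants_to_shrink:
        return False
    return True

print(can_process_new_work(NodeFlags()))                        # True
print(can_process_new_work(NodeFlags(reimage_requested=True)))  # False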

View File

@@ -47,11 +47,28 @@ def get_node_checked(machine_id: UUID) -> Node:
 def on_state_update(
     machine_id: UUID,
     state_update: NodeStateUpdate,
-) -> func.HttpResponse:
+) -> None:
     state = state_update.state
     node = get_node_checked(machine_id)

-    if state == NodeState.init or node.state not in NodeState.ready_for_reset():
+    if state == NodeState.free:
+        if node.reimage_requested or node.delete_requested:
+            logging.info("stopping free node with reset flags: %s", node.machine_id)
+            node.stop()
+            return
+
+        if node.could_shrink_scaleset():
+            logging.info("stopping free node to resize scaleset: %s", node.machine_id)
+            node.set_halt()
+            return
+
+    if state == NodeState.init:
+        if node.delete_requested:
+            node.stop()
+            return
+        node.reimage_requested = False
+        node.save()
+    elif node.state not in NodeState.ready_for_reset():
         if node.state != state:
             node.state = state
             node.save()
@@ -91,26 +108,28 @@ def on_state_update(
                 )
                 node_task.save()
     elif state == NodeState.done:
+        # if tasks are running on the node when it reports as Done
+        # those are stopped early
+        node.mark_tasks_stopped_early()
+
         # Model-validated.
         #
         # This field will be required in the future.
         # For now, it is optional for back compat.
         done_data = cast(Optional[NodeDoneEventData], state_update.data)
         if done_data:
-            # TODO: do something with this done data
             if done_data.error:
                 logging.error(
-                    "node `done` with error: machine_id = %s, data = %s",
+                    "node 'done' with error: machine_id:%s, data:%s",
                     machine_id,
                     done_data,
                 )
     else:
         logging.info("ignoring state updates from the node: %s: %s", machine_id, state)
-    return ok(BoolResult(result=True))


-def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
+def on_worker_event(machine_id: UUID, event: WorkerEvent) -> None:
     if event.running:
         task_id = event.running.task_id
     elif event.done:
@@ -129,37 +148,32 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
             task.state = TaskState.running
         if node.state not in NodeState.ready_for_reset():
             node.state = NodeState.busy
+            node.save()
         node_task.save()

         # Start the clock for the task if it wasn't started already
         # (as happens in 1.0.0 agents)
         task.on_start()
     elif event.done:
-        # Only record exit status if the task isn't already shutting down.
-        #
-        # It's ok for the agent to fail because resources vanish out from underneath
-        # it during deletion.
-        if task.state not in TaskState.shutting_down():
-            exit_status = event.done.exit_status
-            if not exit_status.success:
-                logging.error("task failed: status = %s", exit_status)
-                task.mark_failed(
-                    Error(
-                        code=ErrorCode.TASK_FAILED,
-                        errors=[
-                            "task failed. exit_status = %s" % exit_status,
-                            event.done.stdout,
-                            event.done.stderr,
-                        ],
-                    )
-                )
-
-        task.state = TaskState.stopping
-        if node.state not in NodeState.ready_for_reset():
-            node.state = NodeState.done
-
         node_task.delete()
+
+        exit_status = event.done.exit_status
+        if not exit_status.success:
+            logging.error("task failed. status:%s", exit_status)
+            task.mark_failed(
+                Error(
+                    code=ErrorCode.TASK_FAILED,
+                    errors=[
+                        "task failed. exit_status:%s" % exit_status,
+                        event.done.stdout,
+                        event.done.stderr,
+                    ],
+                )
+            )
+        else:
+            task.mark_stopping()
+
+        node.to_reimage(done=True)
     else:
         err = Error(
             code=ErrorCode.INVALID_REQUEST,
@@ -168,11 +182,9 @@ def on_worker_event(machine_id: UUID, event: WorkerEvent) -> func.HttpResponse:
         raise RequestException(err)

     task.save()
-    node.save()

     task_event = TaskEvent(task_id=task_id, machine_id=machine_id, event_data=event)
     task_event.save()
-    return ok(BoolResult(result=True))


 def post(req: func.HttpRequest) -> func.HttpResponse:
@@ -197,9 +209,11 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
         return not_ok(err, context=ERROR_CONTEXT)

     if event.state_update:
-        return on_state_update(envelope.machine_id, event.state_update)
+        on_state_update(envelope.machine_id, event.state_update)
+        return ok(BoolResult(result=True))
     elif event.worker_event:
-        return on_worker_event(envelope.machine_id, event.worker_event)
+        on_worker_event(envelope.machine_id, event.worker_event)
+        return ok(BoolResult(result=True))
     else:
         err = Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid node event"])
         return not_ok(err, context=ERROR_CONTEXT)
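A small sketch (hypothetical types, not the service code) of the new 'done' handling above: a failing exit status marks the task failed, a clean exit marks it stopping, and the node is flagged for reimage either way.

from dataclasses import dataclass

@dataclass
class ExitStatus:
    success: bool

def handle_done(exit_status: ExitStatus) -> str:
    if not exit_status.success:
        outcome = "task marked failed"
    else:
        outcome = "task marked stopping"
    # node.to_reimage(done=True) runs in either case
    return outcome + "; node flagged for reimage"

print(handle_done(ExitStatus(success=False)))
print(handle_done(ExitStatus(success=True)))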

View File

@@ -6,7 +6,7 @@
 from uuid import UUID

 import azure.functions as func
-from onefuzztypes.enums import ErrorCode
+from onefuzztypes.enums import ErrorCode, NodeState
 from onefuzztypes.models import Error
 from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
 from onefuzztypes.responses import AgentRegistration

@@ -78,7 +78,6 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
     registration_request = parse_uri(AgentRegistrationPost, req)
     if isinstance(registration_request, Error):
         return not_ok(registration_request, context="agent registration")
-    agent_node = Node.get_by_machine_id(registration_request.machine_id)

     pool = Pool.get_by_name(registration_request.pool_name)
     if isinstance(pool, Error):
@@ -90,20 +89,23 @@
             context="agent registration",
         )

-    if agent_node is None:
-        agent_node = Node(
+    node = Node.get_by_machine_id(registration_request.machine_id)
+    if node:
+        if node.version != registration_request.version:
+            NodeMessage.clear_messages(node.machine_id)
+        node.version = registration_request.version
+        node.reimage_requested = False
+        node.state = NodeState.init
+    else:
+        node = Node(
             pool_name=registration_request.pool_name,
             machine_id=registration_request.machine_id,
             scaleset_id=registration_request.scaleset_id,
             version=registration_request.version,
         )
-        agent_node.save()
-    elif agent_node.version.lower != registration_request.version:
-        NodeMessage.clear_messages(agent_node.machine_id)
-        agent_node.version = registration_request.version
-        agent_node.save()
+    node.save()

-    return create_registration_response(agent_node.machine_id, pool)
+    return create_registration_response(node.machine_id, pool)


 def main(req: func.HttpRequest) -> func.HttpResponse:
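The registration handler is now an upsert keyed on machine_id. A rough sketch of that flow with a dict standing in for the Node table (illustrative only, not the onefuzz ORM):

from typing import Dict
from uuid import UUID, uuid4

nodes: Dict[UUID, dict] = {}

def register(machine_id: UUID, version: str) -> dict:
    node = nodes.get(machine_id)
    if node:
        if node["version"] != version:
            node["messages"] = []      # drop stale commands when the agent version changes
        node["version"] = version
        node["reimage_requested"] = False
        node["state"] = "init"         # re-registration always restarts the lifecycle
    else:
        node = {"version": version, "messages": [], "reimage_requested": False, "state": "init"}
    nodes[machine_id] = node
    return node

machine = uuid4()
register(machine, "1.0.0")
print(register(machine, "2.0.0")["state"])  # init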

View File

@@ -4,7 +4,7 @@
 # Licensed under the MIT License.

 import azure.functions as func
-from onefuzztypes.enums import ErrorCode, NodeState
+from onefuzztypes.enums import ErrorCode
 from onefuzztypes.models import Error
 from onefuzztypes.requests import NodeGet, NodeSearch
 from onefuzztypes.responses import BoolResult

@@ -54,8 +54,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
             context=request.machine_id,
         )

-    node.state = NodeState.halt
-    node.save()
+    node.set_halt()
     return ok(BoolResult(result=True))

@@ -72,9 +71,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse:
             context=request.machine_id,
         )

-    node.state = NodeState.done
-    node.save()
+    node.stop()

     return ok(BoolResult(result=True))

View File

@@ -88,6 +88,15 @@ def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceC
     return None


+def clear_queue(name: QueueNameType, *, account_id: str) -> None:
+    queue = get_queue(name, account_id=account_id)
+    if queue:
+        try:
+            queue.clear_messages()
+        except ResourceNotFoundError:
+            return None
+
+
 def send_message(
     name: QueueNameType,
     message: bytes,

@@ -102,6 +111,19 @@
             pass


+def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
+    create_queue(name, account_id=account_id)
+    queue = get_queue(name, account_id=account_id)
+    if queue:
+        try:
+            for message in queue.receive_messages():
+                queue.delete_message(message)
+                return True
+        except ResourceNotFoundError:
+            return False
+    return False
+
+
 A = TypeVar("A", bound=BaseModel)
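The two new helpers have simple semantics: clear_queue drops everything, and remove_first_message consumes at most one message and reports whether one was there. An in-memory sketch of those semantics (collections.deque standing in for the Azure storage queue):

from collections import deque
from typing import Deque

def clear_queue(queue: Deque[str]) -> None:
    queue.clear()

def remove_first_message(queue: Deque[str]) -> bool:
    if queue:
        queue.popleft()
        return True
    return False

q: Deque[str] = deque(["a", "b"])
print(remove_first_message(q))  # True
print(remove_first_message(q))  # True
print(remove_first_message(q))  # False (queue is empty)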

View File

@@ -76,7 +76,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:


 def list_instance_ids(name: UUID) -> Dict[UUID, str]:
-    logging.info("get instance IDs for scaleset: %s", name)
+    logging.debug("get instance IDs for scaleset: %s", name)
     resource_group = get_base_resource_group()
     compute_client = mgmt_client_factory(ComputeManagementClient)

View File

@@ -6,7 +6,7 @@
 import datetime
 import logging
 from typing import Dict, List, Optional, Tuple, Union
-from uuid import UUID
+from uuid import UUID, uuid4

 from onefuzztypes.enums import (
     OS,

@@ -15,7 +15,6 @@ from onefuzztypes.enums import (
     NodeState,
     PoolState,
     ScalesetState,
-    TaskState,
 )
 from onefuzztypes.models import Error
 from onefuzztypes.models import Node as BASE_NODE

@@ -32,14 +31,21 @@ from onefuzztypes.models import (
     WorkUnitSummary,
 )
 from onefuzztypes.primitives import PoolName, Region
-from pydantic import Field
+from pydantic import BaseModel, Field

 from .__version__ import __version__
 from .azure.auth import build_auth
-from .azure.creds import get_fuzz_storage
+from .azure.creds import get_func_storage, get_fuzz_storage
 from .azure.image import get_os
 from .azure.network import Network
-from .azure.queue import create_queue, delete_queue, peek_queue, queue_object
+from .azure.queue import (
+    clear_queue,
+    create_queue,
+    delete_queue,
+    peek_queue,
+    queue_object,
+    remove_first_message,
+)
 from .azure.vmss import (
     UnableToUpdate,
     create_vmss,

@@ -99,7 +105,7 @@ class Node(BASE_NODE, ORMMixin):
         # azure table query always return false when the column does not exist
         # We write the query this way to allow us to get the nodes where the
         # version is not defined as well as the nodes with a mismatched version
-        version_query = "not (version ne '%s')" % __version__
+        version_query = "not (version eq '%s')" % __version__
         return cls.search(query=query, raw_unchecked_filter=version_query)

     @classmethod
@@ -154,16 +160,90 @@ class Node(BASE_NODE, ORMMixin):
         for node in nodes:
             if node.state not in NodeState.ready_for_reset():
                 logging.info(
-                    "stopping task %s on machine_id:%s",
-                    task_id,
+                    "stopping machine_id:%s running task:%s",
                     node.machine_id,
+                    task_id,
                 )
-                node.state = NodeState.done
-                node.save()
+                node.stop()
+
+    def mark_tasks_stopped_early(self) -> None:
+        from .tasks.main import Task
+
+        for entry in NodeTasks.get_by_machine_id(self.machine_id):
+            task = Task.get_by_task_id(entry.task_id)
+            if isinstance(task, Task):
+                task.mark_failed(
+                    Error(
+                        code=ErrorCode.TASK_FAILED,
+                        errors=["node reimaged during task execution"],
+                    )
+                )
+            entry.delete()
+
+    def could_shrink_scaleset(self) -> bool:
+        if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink():
+            return True
+        return False
+
+    def can_process_new_work(self) -> bool:
+        if self.is_outdated():
+            logging.info(
+                "can_schedule old version machine_id:%s version:%s",
+                self.machine_id,
+                self.version,
+            )
+            self.stop()
+            return False
+
+        if self.delete_requested or self.reimage_requested:
+            logging.info(
+                "can_schedule should be recycled. machine_id:%s", self.machine_id
+            )
+            self.stop()
+            return False
+
+        if self.could_shrink_scaleset():
+            self.set_halt()
+            logging.info("node scheduled to shrink. machine_id:%s", self.machine_id)
+            return False
+
+        return True

     def is_outdated(self) -> bool:
         return self.version != __version__

+    def send_message(self, message: NodeCommand) -> None:
+        stop_message = NodeMessage(
+            agent_id=self.machine_id,
+            message=message,
+        )
+        stop_message.save()
+
+    def to_reimage(self, done: bool = False) -> None:
+        if done:
+            if self.state not in NodeState.ready_for_reset():
+                self.state = NodeState.done
+
+        if not self.reimage_requested and not self.delete_requested:
+            logging.info("setting reimage_requested: %s", self.machine_id)
+            self.reimage_requested = True
+        self.save()
+
+    def stop(self) -> None:
+        self.to_reimage()
+        self.send_message(NodeCommand(stop=StopNodeCommand()))
+
+    def set_shutdown(self) -> None:
+        # don't give out more work to the node, but let it finish existing work
+        logging.info("setting delete_requested: %s", self.machine_id)
+        self.delete_requested = True
+        self.save()
+
+    def set_halt(self) -> None:
+        """ Tell the node to stop everything. """
+        self.set_shutdown()
+        self.stop()


 class NodeTasks(BASE_NODE_TASK, ORMMixin):
     @classmethod
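How the new lifecycle helpers compose, per the methods above: stop() flags a reimage (unless a recycle is already pending) and sends the agent a stop command, set_shutdown() only flags deletion so in-flight work can finish, and set_halt() does both. A toy model of that composition (not the ORM class):

class ToyNode:
    def __init__(self) -> None:
        self.reimage_requested = False
        self.delete_requested = False
        self.commands = []

    def stop(self) -> None:
        # to_reimage() only sets the flag if no recycle is pending, then a stop command is sent
        if not self.reimage_requested and not self.delete_requested:
            self.reimage_requested = True
        self.commands.append("stop")

    def set_shutdown(self) -> None:
        # stop handing out new work, let existing work finish
        self.delete_requested = True

    def set_halt(self) -> None:
        # stop everything: shutdown flag plus an immediate stop
        self.set_shutdown()
        self.stop()

n = ToyNode()
n.set_halt()
print(n.reimage_requested, n.delete_requested, n.commands)  # False True ['stop']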
@@ -384,8 +464,7 @@ class Pool(BASE_POOL, ORMMixin):
             scaleset.save()

         for node in nodes:
-            node.state = NodeState.shutdown
-            node.save()
+            node.set_shutdown()

         self.save()

@@ -405,13 +484,7 @@ class Pool(BASE_POOL, ORMMixin):
             scaleset.save()

         for node in nodes:
-            logging.info(
-                "deleting node from pool: %s (%s) - machine_id:%s",
-                self.pool_id,
-                self.name,
-                node.machine_id,
-            )
-            node.delete()
+            node.set_halt()

         self.save()

@@ -496,6 +569,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
     def init(self) -> None:
         logging.info("scaleset init: %s", self.scaleset_id)

+        ScalesetShrinkQueue(self.scaleset_id).create()
+
         # Handle the race condition between a pool being deleted and a
         # scaleset being added to the pool.
         pool = Pool.get_by_name(self.pool_name)
@@ -596,49 +671,62 @@ class Scaleset(BASE_SCALESET, ORMMixin):
     # result = 'did I modify the scaleset in azure'
     def cleanup_nodes(self) -> bool:
         if self.state == ScalesetState.halt:
+            logging.info("halting scaleset: %s", self.scaleset_id)
             self.halt()
             return True

+        to_reimage = []
+        to_delete = []
+
+        outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
+        for node in outdated:
+            logging.info(
+                "node is outdated: %s - node_version:%s api_version:%s",
+                node.machine_id,
+                node.version,
+                __version__,
+            )
+            if node.version == "1.0.0":
+                node.state = NodeState.done
+                to_reimage.append(node)
+            else:
+                node.to_reimage()
+
         nodes = Node.search_states(
             scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
         )
-        outdated = Node.search_outdated(
-            scaleset_id=self.scaleset_id,
-            states=[NodeState.free],
-        )
-        if not (nodes or outdated):
-            logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
+
+        if not outdated and not nodes:
+            logging.info("no nodes need updating: %s", self.scaleset_id)
             return False

-        to_delete = []
-        to_reimage = []
-        for node in outdated:
-            if node.version == "1.0.0":
-                to_reimage.append(node)
-            else:
-                stop_message = NodeMessage(
-                    agent_id=node.machine_id,
-                    message=NodeCommand(stop=StopNodeCommand()),
-                )
-                stop_message.save()
+        # ground truth of existing nodes
+        azure_nodes = list_instance_ids(self.scaleset_id)

         for node in nodes:
-            # delete nodes that are not waiting on the scaleset GC
-            if not node.scaleset_node_exists():
+            if node.machine_id not in azure_nodes:
+                logging.info(
+                    "no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
+                )
                 node.delete()
-            elif node.state in [NodeState.shutdown, NodeState.halt]:
+            elif node.delete_requested:
                 to_delete.append(node)
             else:
-                to_reimage.append(node)
+                if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
+                    node.set_halt()
+                    to_delete.append(node)
+                else:
+                    to_reimage.append(node)

         # Perform operations until they fail due to scaleset getting locked
         try:
             if to_delete:
+                logging.info(
+                    "deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
+                )
                 self.delete_nodes(to_delete)
                 for node in to_delete:
+                    node.set_halt()
                     node.state = NodeState.halt
                     node.save()
@@ -646,69 +734,65 @@ class Scaleset(BASE_SCALESET, ORMMixin):
             self.reimage_nodes(to_reimage)
         except UnableToUpdate:
             logging.info("scaleset update already in progress: %s", self.scaleset_id)

         return True

-    def resize(self) -> None:
-        logging.info(
-            "scaleset resize: %s - current: %s new: %s",
-            self.scaleset_id,
-            self.size,
-            self.new_size,
-        )
-
-        # no work needed to resize
-        if self.new_size is None:
+    def _resize_equal(self) -> None:
+        # NOTE: this is the only place we reset to the 'running' state.
+        # This ensures that our idea of scaleset size agrees with Azure
+        node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
+        if node_count == self.size:
+            logging.info("resize finished: %s", self.scaleset_id)
             self.state = ScalesetState.running
             self.save()
             return
+        else:
+            logging.info(
+                "resize is finished, waiting for nodes to check in: "
+                "%s (%d of %d nodes checked in)",
+                self.scaleset_id,
+                node_count,
+                self.size,
+            )
+            return
+
+    def _resize_grow(self) -> None:
+        try:
+            resize_vmss(self.scaleset_id, self.size)
+        except UnableToUpdate:
+            logging.info("scaleset is mid-operation already")
+        return
+
+    def _resize_shrink(self, to_remove: int) -> None:
+        queue = ScalesetShrinkQueue(self.scaleset_id)
+        for _ in range(to_remove):
+            queue.add_entry()
+
+    def resize(self) -> None:
+        # no longer needing to resize
+        if self.state != ScalesetState.resize:
+            return
+
+        logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)
+
+        # reset the node delete queue
+        ScalesetShrinkQueue(self.scaleset_id).clear()

         # just in case, always ensure size is within max capacity
-        self.new_size = min(self.new_size, self.max_size())
+        self.size = min(self.size, self.max_size())

         # Treat Azure knowledge of the size of the scaleset as "ground truth"
         size = get_vmss_size(self.scaleset_id)
         if size is None:
-            logging.info("scaleset is unavailable. Re-queuing")
-            self.save()
+            logging.info("scaleset is unavailable: %s", self.scaleset_id)
             return

-        if size == self.new_size:
-            # NOTE: this is the only place we reset to the 'running' state.
-            # This ensures that our idea of scaleset size agrees with Azure
-            node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
-            if node_count == self.size:
-                logging.info("resize finished: %s", self.scaleset_id)
-                self.new_size = None
-                self.state = ScalesetState.running
-            else:
-                logging.info(
-                    "resize is finished, waiting for nodes to check in: "
-                    "%s (%d of %d nodes checked in)",
-                    self.scaleset_id,
-                    node_count,
-                    self.size,
-                )
-        # When adding capacity, call the resize API directly
-        elif self.new_size > self.size:
-            try:
-                resize_vmss(self.scaleset_id, self.new_size)
-            except UnableToUpdate:
-                logging.info("scaleset is mid-operation already")
-        # Shut down any nodes without work. Otherwise, rely on Scaleset.reimage_node
-        # to pick up that the scaleset is too big upon task completion
+        if size == self.size:
+            self._resize_equal()
+        elif self.size > size:
+            self._resize_grow()
         else:
-            nodes = Node.search_states(
-                scaleset_id=self.scaleset_id, states=[NodeState.init, NodeState.free]
-            )
-            for node in nodes:
-                if size > self.new_size:
-                    node.state = NodeState.halt
-                    node.save()
-                    size -= 1
-                else:
-                    break
-        self.save()
+            self._resize_shrink(size - self.size)

     def delete_nodes(self, nodes: List[Node]) -> None:
         if not nodes:
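The resize split above reduces to a three-way comparison between the size Azure reports and the requested size. A compact sketch of that dispatch (assumed names, not the Scaleset class):

def resize_action(azure_size: int, requested_size: int) -> str:
    if azure_size == requested_size:
        return "equal: wait for all nodes to check in, then mark the scaleset running"
    if requested_size > azure_size:
        return "grow: call the VMSS resize API directly"
    return "shrink: enqueue %d shrink entries and let free nodes remove themselves" % (
        azure_size - requested_size
    )

print(resize_action(5, 5))
print(resize_action(3, 5))
print(resize_action(5, 3))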
@@ -723,31 +807,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):
         logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
         delete_vmss_nodes(self.scaleset_id, machine_ids)
-        self.size -= len(machine_ids)
-        self.save()

     def reimage_nodes(self, nodes: List[Node]) -> None:
-        from .tasks.main import Task
-
         if not nodes:
             logging.debug("no nodes to reimage")
             return

-        for node in nodes:
-            for entry in NodeTasks.get_by_machine_id(node.machine_id):
-                task = Task.get_by_task_id(entry.task_id)
-                if isinstance(task, Task):
-                    if task.state in [TaskState.stopping, TaskState.stopped]:
-                        continue
-                    task.mark_failed(
-                        Error(
-                            code=ErrorCode.TASK_FAILED,
-                            errors=["node reimaged during task execution"],
-                        )
-                    )
-                entry.delete()
-
         if self.state == ScalesetState.shutdown:
             self.delete_nodes(nodes)
             return
@@ -766,28 +831,28 @@ class Scaleset(BASE_SCALESET, ORMMixin):
         )

     def shutdown(self) -> None:
-        logging.info("scaleset shutdown: %s", self.scaleset_id)
         size = get_vmss_size(self.scaleset_id)
+        logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
         if size is None or size == 0:
-            self.state = ScalesetState.halt
             self.halt()
-            return
+
+        self.save()

     def halt(self) -> None:
+        self.state = ScalesetState.halt
+        ScalesetShrinkQueue(self.scaleset_id).delete()
+
         for node in Node.search_states(scaleset_id=self.scaleset_id):
             logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
             node.delete()

         vmss = get_vmss(self.scaleset_id)
-        if vmss is None:
-            logging.info("scaleset deleted: %s", self.scaleset_id)
-            self.state = ScalesetState.halt
-            self.delete()
-        else:
+        if vmss:
             logging.info("scaleset deleting: %s", self.scaleset_id)
             delete_vmss(self.scaleset_id)
             self.save()
+        else:
+            logging.info("scaleset deleted: %s", self.scaleset_id)
+            self.state = ScalesetState.halt
+            self.delete()

     def max_size(self) -> int:
         # https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
@@ -858,3 +923,30 @@
     @classmethod
     def key_fields(cls) -> Tuple[str, str]:
         return ("pool_name", "scaleset_id")
+
+
+class ShrinkEntry(BaseModel):
+    shrink_id: UUID = Field(default_factory=uuid4)
+
+
+class ScalesetShrinkQueue:
+    def __init__(self, scaleset_id: UUID):
+        self.scaleset_id = scaleset_id
+
+    def queue_name(self) -> str:
+        return "to-shrink-%s" % self.scaleset_id.hex
+
+    def clear(self) -> None:
+        clear_queue(self.queue_name(), account_id=get_func_storage())
+
+    def create(self) -> None:
+        create_queue(self.queue_name(), account_id=get_func_storage())
+
+    def delete(self) -> None:
+        delete_queue(self.queue_name(), account_id=get_func_storage())
+
+    def add_entry(self) -> None:
+        queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage())
+
+    def should_shrink(self) -> bool:
+        return remove_first_message(self.queue_name(), account_id=get_func_storage())
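The ScalesetShrinkQueue is effectively a pool of "remove one node" tickets: _resize_shrink adds one ticket per node to drop, and each free node calls should_shrink() to see whether it should take itself out. An in-memory sketch of that protocol (a list standing in for the Azure queue):

from typing import List

class ToyShrinkQueue:
    def __init__(self) -> None:
        self._tickets: List[str] = []

    def add_entry(self) -> None:
        self._tickets.append("shrink")

    def should_shrink(self) -> bool:
        # take at most one ticket; True means "this node should go away"
        if self._tickets:
            self._tickets.pop(0)
            return True
        return False

q = ToyShrinkQueue()
for _ in range(2):          # resize down by two
    q.add_entry()
print([q.should_shrink() for _ in range(3)])  # [True, True, False]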

View File

@@ -108,7 +108,6 @@ class Task(BASE_TASK, ORMMixin):
         self.save()

     def stopping(self) -> None:
-        # TODO: we need to tell every node currently working on this task to stop
         # TODO: we need to 'unschedule' this task from the existing pools
         self.state = TaskState.stopping

@@ -154,6 +153,11 @@ class Task(BASE_TASK, ORMMixin):
         task = tasks[0]
         return task

+    def mark_stopping(self) -> None:
+        if self.state not in [TaskState.stopped, TaskState.stopping]:
+            self.state = TaskState.stopping
+            self.save()
+
     def mark_failed(self, error: Error) -> None:
         if self.state in [TaskState.stopped, TaskState.stopping]:
             logging.debug(

View File

@@ -31,5 +31,5 @@ pydantic~=1.6.1
 PyJWT~=1.7.1
 requests~=2.24.0
 memoization~=0.3.1
-# onefuzztypes version is set during build
+# onefuzz types version is set during build
 onefuzztypes==0.0.0

View File

@@ -109,6 +109,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
         scaleset.state = ScalesetState.halt
     else:
         scaleset.state = ScalesetState.shutdown
+
     scaleset.save()
     scaleset.auth = None
     return ok(scaleset)

@@ -133,7 +134,7 @@ def patch(req: func.HttpRequest) -> func.HttpResponse:
     )

     if request.size is not None:
-        scaleset.new_size = request.size
+        scaleset.size = request.size
         scaleset.state = ScalesetState.resize
         scaleset.save()

View File

@@ -13,21 +13,28 @@ from ..onefuzzlib.pools import Scaleset
 def process_scaleset(scaleset: Scaleset) -> None:
-    if scaleset.state == ScalesetState.halt:
-        scaleset.halt()
-        return
+    logging.debug("checking scaleset for updates: %s", scaleset.scaleset_id)
+
+    if scaleset.state == ScalesetState.resize:
+        scaleset.resize()

     # if the scaleset is touched during cleanup, don't continue to process it
     if scaleset.cleanup_nodes():
+        logging.debug("scaleset needed cleanup: %s", scaleset.scaleset_id)
         return

-    if scaleset.state in ScalesetState.needs_work():
+    if (
+        scaleset.state in ScalesetState.needs_work()
+        and scaleset.state != ScalesetState.resize
+    ):
         logging.info(
             "exec scaleset state: %s - %s",
             scaleset.scaleset_id,
-            scaleset.state.name,
+            scaleset.state,
         )
-        getattr(scaleset, scaleset.state.name)()
+
+        if hasattr(scaleset, scaleset.state.name):
+            getattr(scaleset, scaleset.state.name)()
         return
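The timer's dispatch is name-based: the scaleset state's name is looked up as a method on the object, and the new hasattr() guard means states without a handler are simply skipped. A tiny illustration of that pattern:

class ToyScaleset:
    state = "init"

    def init(self) -> None:
        print("running init handler")

s = ToyScaleset()
if hasattr(s, s.state):
    getattr(s, s.state)()   # prints "running init handler"
if hasattr(s, "resize"):
    print("never reached: this toy class has no 'resize' handler")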

View File

@@ -399,6 +399,8 @@ class Node(BaseModel):
     scaleset_id: Optional[UUID] = None
     tasks: Optional[List[Tuple[UUID, NodeTaskState]]] = None
     version: str = Field(default="1.0.0")
+    reimage_requested: bool = Field(default=False)
+    delete_requested: bool = Field(default=False)


 class ScalesetSummary(BaseModel):

@@ -447,7 +449,6 @@ class Scaleset(BaseModel):
     image: str
     region: Region
     size: int
-    new_size: Optional[int]
     spot_instances: bool
     error: Optional[Error]
     nodes: Optional[List[ScalesetNodeState]]