allow nodes with multiple tasks to continue on task stop (#567)

As is, when multiple tasks are running on a single node, if any one of them stops, the node gets reimaged.

This changes the behavior such that when a node with multiple tasks has one task stop, the other tasks will continue.
This commit is contained in:
bmc-msft
2021-02-19 18:54:26 -05:00
committed by GitHub
parent 6ba5795f36
commit feb80ecb54
2 changed files with 41 additions and 6 deletions

View File

@ -8,7 +8,7 @@ import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from uuid import UUID from uuid import UUID
from onefuzztypes.enums import ErrorCode, NodeState from onefuzztypes.enums import ErrorCode, NodeState, TaskState
from onefuzztypes.events import ( from onefuzztypes.events import (
EventNodeCreated, EventNodeCreated,
EventNodeDeleted, EventNodeDeleted,
@ -18,7 +18,7 @@ from onefuzztypes.models import Error
from onefuzztypes.models import Node as BASE_NODE from onefuzztypes.models import Node as BASE_NODE
from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey
from onefuzztypes.models import NodeTasks as BASE_NODE_TASK from onefuzztypes.models import NodeTasks as BASE_NODE_TASK
from onefuzztypes.models import Result, StopNodeCommand from onefuzztypes.models import Result, StopNodeCommand, StopTaskNodeCommand
from onefuzztypes.primitives import PoolName from onefuzztypes.primitives import PoolName
from pydantic import Field from pydantic import Field
@ -134,6 +134,14 @@ class Node(BASE_NODE, ORMMixin):
else: else:
node.to_reimage() node.to_reimage()
@classmethod
def cleanup_busy_nodes_without_work(cls) -> None:
# There is a potential race condition if multiple `Node.stop_task` calls
# are made concurrently. By performing this check regularly, any nodes
# that hit this race condition will get cleaned up.
for node in cls.search_states(states=[NodeState.busy]):
node.stop_if_complete()
@classmethod @classmethod
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]: def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
nodes = cls.search(query={"machine_id": [machine_id]}) nodes = cls.search(query={"machine_id": [machine_id]})
@ -178,13 +186,39 @@ class Node(BASE_NODE, ORMMixin):
# gracefully # gracefully
nodes = NodeTasks.get_nodes_by_task_id(task_id) nodes = NodeTasks.get_nodes_by_task_id(task_id)
for node in nodes: for node in nodes:
if node.state not in NodeState.ready_for_reset(): node.send_message(
NodeCommand(stop_task=StopTaskNodeCommand(task_id=task_id))
)
if not node.stop_if_complete():
logging.info( logging.info(
"stopping machine_id:%s running task:%s", "nodes: stopped task on node, "
node.machine_id, "but not reimaging due to other tasks: task_id:%s machine_id:%s",
task_id, task_id,
node.machine_id,
) )
node.stop()
def stop_if_complete(self) -> bool:
# returns True on stopping the node and False if this doesn't stop the node
from ..tasks.main import Task
node_tasks = NodeTasks.get_by_machine_id(self.machine_id)
for node_task in node_tasks:
task = Task.get_by_task_id(node_task.task_id)
# ignore invalid tasks when deciding if the node should be
# shutdown
if isinstance(task, Error):
continue
if task.state not in TaskState.shutting_down():
return False
logging.info(
"node: stopping busy node with all tasks complete: %s",
self.machine_id,
)
self.stop()
return True
def mark_tasks_stopped_early(self) -> None: def mark_tasks_stopped_early(self) -> None:
from ..tasks.main import Task from ..tasks.main import Task

View File

@ -50,6 +50,7 @@ def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa:
# get to empty scalesets, which can safely be deleted. # get to empty scalesets, which can safely be deleted.
Node.mark_outdated_nodes() Node.mark_outdated_nodes()
Node.cleanup_busy_nodes_without_work()
nodes = Node.search_states(states=NodeState.needs_work()) nodes = Node.search_states(states=NodeState.needs_work())
for node in sorted(nodes, key=lambda x: x.machine_id): for node in sorted(nodes, key=lambda x: x.machine_id):
logging.info("update node: %s", node.machine_id) logging.info("update node: %s", node.machine_id)