reimage outdated nodes (#33)

* reimage outdated nodes

* import sort, version check

* clear node messages on registration

Co-authored-by: bmc-msft <41130664+bmc-msft@users.noreply.github.com>
This commit is contained in:
Cheick Keita
2020-09-29 11:59:03 -07:00
committed by GitHub
parent 35aac1122f
commit 5cab62b310
4 changed files with 58 additions and 10 deletions

View File

@ -26,6 +26,7 @@ from onefuzztypes.models import Scaleset as BASE_SCALESET
from onefuzztypes.models import (
ScalesetNodeState,
ScalesetSummary,
StopNodeCommand,
WorkSet,
WorkSetSummary,
WorkUnitSummary,
@ -80,6 +81,28 @@ class Node(BASE_NODE, ORMMixin):
query["pool_name"] = [pool_name]
return cls.search(query=query)
@classmethod
def search_outdated(
cls,
*,
scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None,
) -> List["Node"]:
query: QueryFilter = {}
if scaleset_id:
query["scaleset_id"] = [scaleset_id]
if states:
query["state"] = states
if pool_name:
query["pool_name"] = [pool_name]
# azure table query always return false when the column does not exist
# We write the query this way to allow us to get the nodes where the
# version is not defined as well as the nodes with a mismatched version
version_query = "not (version ne '%s')" % __version__
return cls.search(query=query, raw_unchecked_filter=version_query)
@classmethod
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
nodes = cls.search(query={"machine_id": [machine_id]})
@ -132,9 +155,7 @@ class Node(BASE_NODE, ORMMixin):
for node in nodes:
if node.state not in NodeState.ready_for_reset():
logging.info(
"stopping task %s on machine_id:%s",
task_id,
node.machine_id,
"stopping task %s on machine_id:%s", task_id, node.machine_id,
)
node.state = NodeState.done
node.save()
@ -203,6 +224,12 @@ class NodeMessage(ORMMixin):
client.commit_batch(cls.table_name(), batch)
@classmethod
def clear_messages(cls, agent_id: UUID) -> None:
messages = cls.get_messages(agent_id)
message_ids = [m.message_id for m in messages]
cls.delete_messages(agent_id, message_ids)
class Pool(BASE_POOL, ORMMixin):
@classmethod
@ -569,13 +596,29 @@ class Scaleset(BASE_SCALESET, ORMMixin):
nodes = Node.search_states(
scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
)
if not nodes:
outdated = Node.search_outdated(
scaleset_id=self.scaleset_id,
states=[NodeState.free],
)
if not (nodes or outdated):
logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
return False
to_delete = []
to_reimage = []
for node in outdated:
if node.version == "1.0.0":
to_reimage.append(node)
else:
stop_message = NodeMessage(
agent_id=node.machine_id,
message=NodeCommand(stop=StopNodeCommand()),
)
stop_message.save()
for node in nodes:
# delete nodes that are not waiting on the scaleset GC
if not node.scaleset_node_exists():
@ -779,8 +822,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
break
if not node_state:
node_state = ScalesetNodeState(
machine_id=machine_id,
instance_id=instance_id,
machine_id=machine_id, instance_id=instance_id,
)
self.nodes.append(node_state)