reimage outdated nodes (#33)

* reimage outdated nodes

* import sort, version check

* clear node messages on registration

Co-authored-by: bmc-msft <41130664+bmc-msft@users.noreply.github.com>
This commit is contained in:
Cheick Keita
2020-09-29 11:59:03 -07:00
committed by GitHub
parent 35aac1122f
commit 5cab62b310
4 changed files with 58 additions and 10 deletions

View File

@ -115,7 +115,9 @@ impl Registration {
if managed {
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;
url.query_pairs_mut().append_pair("scaleset_id", &scaleset);
url.query_pairs_mut()
.append_pair("scaleset_id", &scaleset)
.append_pair("version", env!("ONEFUZZ_VERSION"));
}
// The registration can fail because this call is made before the virtual machine scaleset is done provisioning
// The authentication layer of the service will reject this request when that happens

View File

@ -14,7 +14,7 @@ from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.agent_authorization import verify_token
from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_name
from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.pools import Node, Pool
from ..onefuzzlib.pools import Node, NodeMessage, Pool
from ..onefuzzlib.request import not_ok, ok, parse_uri
@ -44,7 +44,6 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(get_registration, Error):
return not_ok(get_registration, context="agent registration")
# check if an existing registration exists
agent_node = Node.get_by_machine_id(get_registration.machine_id)
if agent_node is None:
@ -79,7 +78,6 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
registration_request = parse_uri(AgentRegistrationPost, req)
if isinstance(registration_request, Error):
return not_ok(registration_request, context="agent registration")
# check if an existing registration exists
agent_node = Node.get_by_machine_id(registration_request.machine_id)
pool = Pool.get_by_name(registration_request.pool_name)
@ -97,8 +95,13 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
pool_name=registration_request.pool_name,
machine_id=registration_request.machine_id,
scaleset_id=registration_request.scaleset_id,
version=registration_request.version
)
agent_node.save()
elif agent_node.version.lower != registration_request.version:
NodeMessage.clear_messages(agent_node.machine_id)
agent_node.version = registration_request.version
agent_node.save()
return create_registration_response(agent_node.machine_id, pool)

View File

@ -26,6 +26,7 @@ from onefuzztypes.models import Scaleset as BASE_SCALESET
from onefuzztypes.models import (
ScalesetNodeState,
ScalesetSummary,
StopNodeCommand,
WorkSet,
WorkSetSummary,
WorkUnitSummary,
@ -80,6 +81,28 @@ class Node(BASE_NODE, ORMMixin):
query["pool_name"] = [pool_name]
return cls.search(query=query)
@classmethod
def search_outdated(
    cls,
    *,
    scaleset_id: Optional[UUID] = None,
    states: Optional[List[NodeState]] = None,
    pool_name: Optional[str] = None,
) -> List["Node"]:
    """Search for nodes whose version does not match the service version.

    Nodes that have no ``version`` column at all are returned as well.

    Args:
        scaleset_id: When given, only nodes of this scaleset are considered.
        states: When given (and non-empty), only nodes in one of these
            states are considered.
        pool_name: When given, only nodes of this pool are considered.

    Returns:
        The nodes matching the optional filters whose ``version`` is missing
        or different from ``__version__``.
    """
    query: QueryFilter = {}
    if scaleset_id:
        query["scaleset_id"] = [scaleset_id]
    if states:
        query["state"] = states
    if pool_name:
        query["pool_name"] = [pool_name]

    # An Azure table comparison always evaluates to false when the column
    # does not exist. Negating an *equality* test therefore selects both the
    # nodes where the version is not defined and the nodes with a mismatched
    # version, while excluding only the nodes already on the current version.
    # (Negating `ne` would do the opposite and return the up-to-date nodes.)
    version_query = "not (version eq '%s')" % __version__
    return cls.search(query=query, raw_unchecked_filter=version_query)
@classmethod
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
nodes = cls.search(query={"machine_id": [machine_id]})
@ -132,9 +155,7 @@ class Node(BASE_NODE, ORMMixin):
for node in nodes:
if node.state not in NodeState.ready_for_reset():
logging.info(
"stopping task %s on machine_id:%s",
task_id,
node.machine_id,
"stopping task %s on machine_id:%s", task_id, node.machine_id,
)
node.state = NodeState.done
node.save()
@ -203,6 +224,12 @@ class NodeMessage(ORMMixin):
client.commit_batch(cls.table_name(), batch)
@classmethod
def clear_messages(cls, agent_id: UUID) -> None:
    """Delete every message currently returned by ``get_messages`` for an agent.

    Args:
        agent_id: Identifier of the agent whose messages are removed.
    """
    # Collect the ids first, then hand them to delete_messages in one call.
    pending_ids = [entry.message_id for entry in cls.get_messages(agent_id)]
    cls.delete_messages(agent_id, pending_ids)
class Pool(BASE_POOL, ORMMixin):
@classmethod
@ -569,13 +596,29 @@ class Scaleset(BASE_SCALESET, ORMMixin):
nodes = Node.search_states(
scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
)
if not nodes:
outdated = Node.search_outdated(
scaleset_id=self.scaleset_id,
states=[NodeState.free],
)
if not (nodes or outdated):
logging.debug("scaleset node gc done (no nodes) %s", self.scaleset_id)
return False
to_delete = []
to_reimage = []
for node in outdated:
if node.version == "1.0.0":
to_reimage.append(node)
else:
stop_message = NodeMessage(
agent_id=node.machine_id,
message=NodeCommand(stop=StopNodeCommand()),
)
stop_message.save()
for node in nodes:
# delete nodes that are not waiting on the scaleset GC
if not node.scaleset_node_exists():
@ -779,8 +822,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
break
if not node_state:
node_state = ScalesetNodeState(
machine_id=machine_id,
instance_id=instance_id,
machine_id=machine_id, instance_id=instance_id,
)
self.nodes.append(node_state)

View File

@ -82,6 +82,7 @@ class AgentRegistrationPost(BaseRequest):
pool_name: PoolName
scaleset_id: Optional[UUID]
machine_id: UUID
version: str
class PoolCreate(BaseRequest):