split out node, scaleset, and pool code (#507)

bmc-msft authored 2021-02-04 19:07:49 -05:00, committed by GitHub
parent 81263c9065
commit a02e084522
24 changed files with 1302 additions and 1229 deletions
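The change is a pure module split: everything that previously lived in onefuzzlib/pools.py now lives under onefuzzlib/workers/ (nodes.py, pools.py, scalesets.py), and every caller's import path is updated to match. The sketch below is illustrative only and is not part of the diff: it shows how the three split-out classes are typically wired together, using the create() signatures from the new files. The pool name, VM SKU, image, and version strings are placeholders, and actually running it requires the OneFuzz Azure Functions environment (storage tables, queues, and event sinks) behind ORMMixin.

from uuid import uuid4

from onefuzztypes.enums import OS, Architecture
from onefuzztypes.primitives import PoolName, Region

from __app__.onefuzzlib.workers.nodes import Node
from __app__.onefuzzlib.workers.pools import Pool
from __app__.onefuzzlib.workers.scalesets import Scaleset

# one pool, one scaleset in that pool, one node registered from that scaleset
pool = Pool.create(
    name=PoolName("example-pool"),
    os=OS.linux,
    arch=Architecture.x86_64,
    managed=True,
    client_id=None,
    autoscale=None,
)
scaleset = Scaleset.create(
    pool_name=pool.name,
    vm_sku="Standard_D2s_v3",  # placeholder SKU
    image="Canonical:UbuntuServer:18.04-LTS:latest",  # placeholder image
    region=Region("eastus"),
    size=10,
    spot_instances=False,
    tags={},
)
node = Node.create(
    pool_name=pool.name,
    machine_id=uuid4(),
    scaleset_id=scaleset.scaleset_id,
    version="1.2.3",  # agent version string
)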

View File

@@ -11,9 +11,9 @@ from onefuzztypes.responses import CanSchedule
 from ..onefuzzlib.endpoint_authorization import call_if_agent
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node
 from ..onefuzzlib.request import not_ok, ok, parse_request
 from ..onefuzzlib.tasks.main import Task
+from ..onefuzzlib.workers.nodes import Node
 def post(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -10,8 +10,8 @@ from onefuzztypes.responses import BoolResult, PendingNodeCommand
 from ..onefuzzlib.endpoint_authorization import call_if_agent
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import NodeMessage
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.nodes import NodeMessage
 def get(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -18,8 +18,9 @@ from ..onefuzzlib.azure.queue import get_queue_sas
 from ..onefuzzlib.azure.storage import StorageType
 from ..onefuzzlib.endpoint_authorization import call_if_agent
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node, Pool
 from ..onefuzzlib.request import not_ok, ok, parse_uri
+from ..onefuzzlib.workers.nodes import Node
+from ..onefuzzlib.workers.pools import Pool
 def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpResponse:

View File

@@ -11,8 +11,8 @@ from onefuzztypes.responses import BoolResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node, NodeTasks
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.nodes import Node, NodeTasks
 def get(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -11,8 +11,8 @@ from onefuzztypes.responses import BoolResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.nodes import Node
 def post(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -25,9 +25,9 @@ from onefuzztypes.models import (
     WorkerRunningEvent,
 )
-from ..onefuzzlib.pools import Node, NodeTasks
 from ..onefuzzlib.task_event import TaskEvent
 from ..onefuzzlib.tasks.main import Task
+from ..onefuzzlib.workers.nodes import Node, NodeTasks
 MAX_OUTPUT_SIZE = 4096

View File

@@ -10,8 +10,10 @@ from typing import List
 from onefuzztypes.enums import NodeState, ScalesetState
 from onefuzztypes.models import AutoScaleConfig, TaskPool
-from .pools import Node, Pool, Scaleset
 from .tasks.main import Task
+from .workers.nodes import Node
+from .workers.pools import Pool
+from .workers.scalesets import Scaleset
 def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:

View File

@@ -13,9 +13,10 @@ from onefuzztypes.enums import ErrorCode
 from onefuzztypes.models import Error, UserInfo
 from .azure.creds import get_scaleset_principal_id
-from .pools import Pool, Scaleset
 from .request import not_ok
 from .user_credentials import parse_jwt_token
+from .workers.pools import Pool
+from .workers.scalesets import Scaleset
 @cached(ttl=60)

File diff suppressed because it is too large

View File

@@ -24,8 +24,10 @@ from ..azure.queue import create_queue, delete_queue
 from ..azure.storage import StorageType
 from ..events import send_event
 from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
-from ..pools import Node, Pool, Scaleset
 from ..proxy_forward import ProxyForward
+from ..workers.nodes import Node
+from ..workers.pools import Pool
+from ..workers.scalesets import Scaleset
 class Task(BASE_TASK, ORMMixin):

View File

@@ -14,7 +14,7 @@ from pydantic import BaseModel
 from ..azure.containers import blob_exists, get_container_sas_url
 from ..azure.storage import StorageType
-from ..pools import Pool
+from ..workers.pools import Pool
 from .config import build_task_config, get_setup_container
 from .main import Task

View File

@@ -60,10 +60,12 @@ def queue_update(
 def execute_update(update: Update) -> None:
     from .jobs import Job
     from .orm import ORMMixin
-    from .pools import Node, Pool, Scaleset
     from .proxy import Proxy
     from .repro import Repro
     from .tasks.main import Task
+    from .workers.nodes import Node
+    from .workers.pools import Pool
+    from .workers.scalesets import Scaleset
     update_objects: Dict[UpdateType, Type[ORMMixin]] = {
         UpdateType.Task: Task,

View File

@@ -0,0 +1,432 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import List, Optional, Tuple
from uuid import UUID
from onefuzztypes.enums import ErrorCode, NodeState
from onefuzztypes.events import (
EventNodeCreated,
EventNodeDeleted,
EventNodeStateUpdated,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Node as BASE_NODE
from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey
from onefuzztypes.models import NodeTasks as BASE_NODE_TASK
from onefuzztypes.models import Result, StopNodeCommand
from onefuzztypes.primitives import PoolName
from pydantic import Field
from ..__version__ import __version__
from ..azure.vmss import get_instance_id
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1)
NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7)
# Future work:
#
# Enabling autoscaling for the scalesets based on the pool work queues.
# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
class Node(BASE_NODE, ORMMixin):
# should only be set by Scaleset.reimage_nodes
# should only be unset during agent_registration POST
reimage_queued: bool = Field(default=False)
@classmethod
def create(
cls,
*,
pool_name: PoolName,
machine_id: UUID,
scaleset_id: Optional[UUID],
version: str,
) -> "Node":
node = cls(
pool_name=pool_name,
machine_id=machine_id,
scaleset_id=scaleset_id,
version=version,
)
node.save()
send_event(
EventNodeCreated(
machine_id=node.machine_id,
scaleset_id=node.scaleset_id,
pool_name=node.pool_name,
)
)
return node
@classmethod
def search_states(
cls,
*,
scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None,
) -> List["Node"]:
query: QueryFilter = {}
if scaleset_id:
query["scaleset_id"] = [scaleset_id]
if states:
query["state"] = states
if pool_name:
query["pool_name"] = [pool_name]
return cls.search(query=query)
@classmethod
def search_outdated(
cls,
*,
scaleset_id: Optional[UUID] = None,
states: Optional[List[NodeState]] = None,
pool_name: Optional[str] = None,
exclude_update_scheduled: bool = False,
num_results: Optional[int] = None,
) -> List["Node"]:
query: QueryFilter = {}
if scaleset_id:
query["scaleset_id"] = [scaleset_id]
if states:
query["state"] = states
if pool_name:
query["pool_name"] = [pool_name]
if exclude_update_scheduled:
query["reimage_requested"] = [False]
query["delete_requested"] = [False]
        # Azure Table queries always return false when the column does not exist.
        # We write the query this way so we can get nodes where the version is
        # not defined, as well as nodes with a mismatched version.
version_query = "not (version eq '%s')" % __version__
return cls.search(
query=query, raw_unchecked_filter=version_query, num_results=num_results
)
@classmethod
def mark_outdated_nodes(cls) -> None:
        # only update 500 nodes at a time to mitigate timeout issues
outdated = cls.search_outdated(exclude_update_scheduled=True, num_results=500)
for node in outdated:
logging.info(
"node is outdated: %s - node_version:%s api_version:%s",
node.machine_id,
node.version,
__version__,
)
if node.version == "1.0.0":
node.to_reimage(done=True)
else:
node.to_reimage()
@classmethod
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
nodes = cls.search(query={"machine_id": [machine_id]})
if not nodes:
return None
if len(nodes) != 1:
return None
return nodes[0]
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("pool_name", "machine_id")
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"tasks": ...}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"machine_id": ...,
"state": ...,
"scaleset_id": ...,
}
def scaleset_node_exists(self) -> bool:
if self.scaleset_id is None:
return False
from .scalesets import Scaleset
scaleset = Scaleset.get_by_id(self.scaleset_id)
if not isinstance(scaleset, Scaleset):
return False
instance_id = get_instance_id(scaleset.scaleset_id, self.machine_id)
return isinstance(instance_id, str)
@classmethod
def stop_task(cls, task_id: UUID) -> None:
# For now, this just re-images the node. Eventually, this
# should send a message to the node to let the agent shut down
# gracefully
nodes = NodeTasks.get_nodes_by_task_id(task_id)
for node in nodes:
if node.state not in NodeState.ready_for_reset():
logging.info(
"stopping machine_id:%s running task:%s",
node.machine_id,
task_id,
)
node.stop()
def mark_tasks_stopped_early(self) -> None:
from ..tasks.main import Task
for entry in NodeTasks.get_by_machine_id(self.machine_id):
task = Task.get_by_task_id(entry.task_id)
if isinstance(task, Task):
task.mark_failed(
Error(
code=ErrorCode.TASK_FAILED,
errors=["node reimaged during task execution"],
)
)
def could_shrink_scaleset(self) -> bool:
from .scalesets import ScalesetShrinkQueue
if self.scaleset_id and ScalesetShrinkQueue(self.scaleset_id).should_shrink():
return True
return False
def can_process_new_work(self) -> bool:
if self.is_outdated():
logging.info(
"can_schedule old version machine_id:%s version:%s",
self.machine_id,
self.version,
)
self.stop()
return False
if self.state in NodeState.ready_for_reset():
logging.info(
"can_schedule node is set for reset. machine_id:%s", self.machine_id
)
return False
if self.delete_requested:
logging.info(
"can_schedule is set to be deleted. machine_id:%s",
self.machine_id,
)
self.stop()
return False
if self.reimage_requested:
logging.info(
"can_schedule is set to be reimaged. machine_id:%s",
self.machine_id,
)
self.stop()
return False
if self.could_shrink_scaleset():
self.set_halt()
logging.info("node scheduled to shrink. machine_id:%s", self.machine_id)
return False
return True
def is_outdated(self) -> bool:
return self.version != __version__
def send_message(self, message: NodeCommand) -> None:
NodeMessage(
machine_id=self.machine_id,
message=message,
).save()
def to_reimage(self, done: bool = False) -> None:
if done:
if self.state not in NodeState.ready_for_reset():
self.state = NodeState.done
if not self.reimage_requested and not self.delete_requested:
logging.info("setting reimage_requested: %s", self.machine_id)
self.reimage_requested = True
self.save()
def add_ssh_public_key(self, public_key: str) -> Result[None]:
if self.scaleset_id is None:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["only able to add ssh keys to scaleset nodes"],
)
if not public_key.endswith("\n"):
public_key += "\n"
self.send_message(
NodeCommand(add_ssh_key=NodeCommandAddSshKey(public_key=public_key))
)
return None
def stop(self) -> None:
self.to_reimage()
self.send_message(NodeCommand(stop=StopNodeCommand()))
def set_shutdown(self) -> None:
# don't give out more work to the node, but let it finish existing work
logging.info("setting delete_requested: %s", self.machine_id)
self.delete_requested = True
self.save()
def set_halt(self) -> None:
""" Tell the node to stop everything. """
self.set_shutdown()
self.stop()
self.set_state(NodeState.halt)
@classmethod
def get_dead_nodes(
cls, scaleset_id: UUID, expiration_period: datetime.timedelta
) -> List["Node"]:
time_filter = "heartbeat lt datetime'%s'" % (
(datetime.datetime.utcnow() - expiration_period).isoformat()
)
return cls.search(
query={"scaleset_id": [scaleset_id]},
raw_unchecked_filter=time_filter,
)
@classmethod
def reimage_long_lived_nodes(cls, scaleset_id: UUID) -> None:
"""
Mark any excessively long lived node to be re-imaged.
This helps keep nodes on scalesets that use `latest` OS image SKUs
reasonably up-to-date with OS patches without disrupting running
fuzzing tasks with patch reboot cycles.
"""
time_filter = "Timestamp lt datetime'%s'" % (
(datetime.datetime.utcnow() - NODE_REIMAGE_TIME).isoformat()
)
# skip any nodes already marked for reimage/deletion
for node in cls.search(
query={
"scaleset_id": [scaleset_id],
"reimage_requested": [False],
"delete_requested": [False],
},
raw_unchecked_filter=time_filter,
):
node.to_reimage()
def set_state(self, state: NodeState) -> None:
if self.state != state:
self.state = state
send_event(
EventNodeStateUpdated(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
state=state,
)
)
self.save()
def delete(self) -> None:
self.mark_tasks_stopped_early()
NodeTasks.clear_by_machine_id(self.machine_id)
NodeMessage.clear_messages(self.machine_id)
super().delete()
send_event(
EventNodeDeleted(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
)
)
class NodeTasks(BASE_NODE_TASK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("machine_id", "task_id")
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"machine_id": ...,
"task_id": ...,
"state": ...,
}
@classmethod
def get_nodes_by_task_id(cls, task_id: UUID) -> List["Node"]:
result = []
for entry in cls.search(query={"task_id": [task_id]}):
node = Node.get_by_machine_id(entry.machine_id)
if node:
result.append(node)
return result
@classmethod
def get_node_assignments(cls, task_id: UUID) -> List[NodeAssignment]:
result = []
for entry in cls.search(query={"task_id": [task_id]}):
node = Node.get_by_machine_id(entry.machine_id)
if node:
node_assignment = NodeAssignment(
node_id=node.machine_id,
scaleset_id=node.scaleset_id,
state=entry.state,
)
result.append(node_assignment)
return result
@classmethod
def get_by_machine_id(cls, machine_id: UUID) -> List["NodeTasks"]:
return cls.search(query={"machine_id": [machine_id]})
@classmethod
def get_by_task_id(cls, task_id: UUID) -> List["NodeTasks"]:
return cls.search(query={"task_id": [task_id]})
@classmethod
def clear_by_machine_id(cls, machine_id: UUID) -> None:
logging.info("clearing tasks for node: %s", machine_id)
for entry in cls.get_by_machine_id(machine_id):
entry.delete()
# this isn't anticipated to be needed by the client, hence it not
# being in onefuzztypes
class NodeMessage(ORMMixin):
machine_id: UUID
message_id: str = Field(default_factory=datetime.datetime.utcnow().timestamp)
message: NodeCommand
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("machine_id", "message_id")
@classmethod
def get_messages(
cls, machine_id: UUID, num_results: int = None
) -> List["NodeMessage"]:
entries: List["NodeMessage"] = cls.search(
query={"machine_id": [machine_id]}, num_results=num_results
)
return entries
@classmethod
def clear_messages(cls, machine_id: UUID) -> None:
logging.info("clearing messages for node: %s", machine_id)
messages = cls.get_messages(machine_id)
for message in messages:
message.delete()
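The two raw Azure Table filters built in this module (the version filter in search_outdated and the heartbeat filter in get_dead_nodes) are plain strings handed to ORMMixin.search as raw_unchecked_filter. A minimal, runnable illustration of just the string construction, with a stand-in version value and no Azure calls:

import datetime

API_VERSION = "1.2.3"  # stand-in for the service __version__

# matches nodes whose version differs *or* is missing entirely, since Azure
# Table comparisons against a missing column evaluate to false
version_query = "not (version eq '%s')" % API_VERSION

# matches nodes whose last heartbeat is older than the expiration window
expiration = datetime.datetime.utcnow() - datetime.timedelta(hours=1)
time_filter = "heartbeat lt datetime'%s'" % expiration.isoformat()

print(version_query)  # not (version eq '1.2.3')
print(time_filter)    # heartbeat lt datetime'2021-...'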

View File

@@ -0,0 +1,239 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import OS, Architecture, ErrorCode, PoolState, ScalesetState
from onefuzztypes.events import EventPoolCreated, EventPoolDeleted
from onefuzztypes.models import AutoScaleConfig, Error
from onefuzztypes.models import Pool as BASE_POOL
from onefuzztypes.models import (
ScalesetSummary,
WorkSet,
WorkSetSummary,
WorkUnitSummary,
)
from onefuzztypes.primitives import PoolName
from ..azure.queue import create_queue, delete_queue, peek_queue, queue_object
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1)
NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7)
# Future work:
#
# Enabling autoscaling for the scalesets based on the pool work queues.
# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
class Pool(BASE_POOL, ORMMixin):
@classmethod
def create(
cls,
*,
name: PoolName,
os: OS,
arch: Architecture,
managed: bool,
client_id: Optional[UUID],
autoscale: Optional[AutoScaleConfig],
) -> "Pool":
entry = cls(
name=name,
os=os,
arch=arch,
managed=managed,
client_id=client_id,
config=None,
autoscale=autoscale,
)
entry.save()
send_event(
EventPoolCreated(
pool_name=name,
os=os,
arch=arch,
managed=managed,
autoscale=autoscale,
)
)
return entry
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {
"nodes": ...,
"queue": ...,
"work_queue": ...,
"config": ...,
"node_summary": ...,
}
def export_exclude(self) -> Optional[MappingIntStrAny]:
return {
"etag": ...,
"timestamp": ...,
}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_id": ...,
"os": ...,
"state": ...,
"managed": ...,
}
def populate_scaleset_summary(self) -> None:
from .scalesets import Scaleset
self.scaleset_summary = [
ScalesetSummary(scaleset_id=x.scaleset_id, state=x.state)
for x in Scaleset.search_by_pool(self.name)
]
def populate_work_queue(self) -> None:
self.work_queue = []
# Only populate the work queue summaries if the pool is initialized. We
# can then be sure that the queue is available in the operations below.
if self.state == PoolState.init:
return
worksets = peek_queue(
self.get_pool_queue(), StorageType.corpus, object_type=WorkSet
)
for workset in worksets:
work_units = [
WorkUnitSummary(
job_id=work_unit.job_id,
task_id=work_unit.task_id,
task_type=work_unit.task_type,
)
for work_unit in workset.work_units
]
self.work_queue.append(WorkSetSummary(work_units=work_units))
def get_pool_queue(self) -> str:
return "pool-%s" % self.pool_id.hex
def init(self) -> None:
create_queue(self.get_pool_queue(), StorageType.corpus)
self.state = PoolState.running
self.save()
def schedule_workset(self, work_set: WorkSet) -> bool:
# Don't schedule work for pools that can't and won't do work.
if self.state in [PoolState.shutdown, PoolState.halt]:
return False
return queue_object(
self.get_pool_queue(),
work_set,
StorageType.corpus,
)
@classmethod
def get_by_id(cls, pool_id: UUID) -> Union[Error, "Pool"]:
pools = cls.search(query={"pool_id": [pool_id]})
if not pools:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])
if len(pools) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
)
pool = pools[0]
return pool
@classmethod
def get_by_name(cls, name: PoolName) -> Union[Error, "Pool"]:
pools = cls.search(query={"name": [name]})
if not pools:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find pool"])
if len(pools) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["error identifying pool"]
)
pool = pools[0]
return pool
@classmethod
def search_states(cls, *, states: Optional[List[PoolState]] = None) -> List["Pool"]:
query: QueryFilter = {}
if states:
query["state"] = states
return cls.search(query=query)
def set_shutdown(self, now: bool) -> None:
if self.state in [PoolState.halt, PoolState.shutdown]:
return
if now:
self.state = PoolState.halt
else:
self.state = PoolState.shutdown
self.save()
def shutdown(self) -> None:
""" shutdown allows nodes to finish current work then delete """
from .nodes import Node
from .scalesets import Scaleset
scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes:
logging.info("pool stopped, deleting: %s", self.name)
self.state = PoolState.halt
self.delete()
return
for scaleset in scalesets:
scaleset.set_shutdown(now=False)
for node in nodes:
node.set_shutdown()
self.save()
def halt(self) -> None:
""" halt the pool immediately """
from .nodes import Node
from .scalesets import Scaleset
scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes:
delete_queue(self.get_pool_queue(), StorageType.corpus)
logging.info("pool stopped, deleting: %s", self.name)
self.state = PoolState.halt
self.delete()
return
for scaleset in scalesets:
scaleset.state = ScalesetState.halt
scaleset.save()
for node in nodes:
node.set_halt()
self.save()
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("name", "pool_id")
def delete(self) -> None:
super().delete()
send_event(EventPoolDeleted(pool_name=self.name))
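The save_exclude, export_exclude, and telemetry_include methods above return pydantic-style include/exclude mappings (MappingIntStrAny) in which a value of ... selects the entire field; ORMMixin passes them to the model's .dict() when persisting the entity or emitting telemetry. A small, self-contained sketch of that pydantic v1 behavior using a stand-in model (not the real Pool):

from typing import List

from pydantic import BaseModel


class ExamplePool(BaseModel):
    pool_id: str = "abc123"
    os: str = "linux"
    managed: bool = True
    nodes: List[str] = []


pool = ExamplePool()
# exclude the whole 'nodes' field when saving, as Pool.save_exclude does
print(pool.dict(exclude={"nodes": ...}))
# keep only the listed fields for telemetry, as Pool.telemetry_include does
print(pool.dict(include={"pool_id": ..., "os": ..., "managed": ...}))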

View File

@@ -0,0 +1,601 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID, uuid4
from onefuzztypes.enums import ErrorCode, NodeState, PoolState, ScalesetState
from onefuzztypes.events import (
EventScalesetCreated,
EventScalesetDeleted,
EventScalesetFailed,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Scaleset as BASE_SCALESET
from onefuzztypes.models import ScalesetNodeState
from onefuzztypes.primitives import PoolName, Region
from pydantic import BaseModel, Field
from ..azure.auth import build_auth
from ..azure.image import get_os
from ..azure.network import Network
from ..azure.queue import (
clear_queue,
create_queue,
delete_queue,
queue_object,
remove_first_message,
)
from ..azure.storage import StorageType
from ..azure.vmss import (
UnableToUpdate,
create_vmss,
delete_vmss,
delete_vmss_nodes,
get_vmss,
get_vmss_size,
list_instance_ids,
reimage_vmss_nodes,
resize_vmss,
update_extensions,
)
from ..events import send_event
from ..extension import fuzz_extensions
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from .nodes import Node
NODE_EXPIRATION_TIME: datetime.timedelta = datetime.timedelta(hours=1)
NODE_REIMAGE_TIME: datetime.timedelta = datetime.timedelta(days=7)
# Future work:
#
# Enabling autoscaling for the scalesets based on the pool work queues.
# https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
class Scaleset(BASE_SCALESET, ORMMixin):
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"nodes": ...}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"scaleset_id": ...,
"os": ...,
"vm_sku": ...,
"size": ...,
"spot_instances": ...,
}
@classmethod
def create(
cls,
*,
pool_name: PoolName,
vm_sku: str,
image: str,
region: Region,
size: int,
spot_instances: bool,
tags: Dict[str, str],
client_id: Optional[UUID] = None,
client_object_id: Optional[UUID] = None,
) -> "Scaleset":
entry = cls(
pool_name=pool_name,
vm_sku=vm_sku,
image=image,
region=region,
size=size,
spot_instances=spot_instances,
auth=build_auth(),
client_id=client_id,
client_object_id=client_object_id,
tags=tags,
)
entry.save()
send_event(
EventScalesetCreated(
scaleset_id=entry.scaleset_id,
pool_name=entry.pool_name,
vm_sku=vm_sku,
image=image,
region=region,
size=size,
)
)
return entry
@classmethod
def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]:
return cls.search(query={"pool_name": [pool_name]})
@classmethod
def get_by_id(cls, scaleset_id: UUID) -> Union[Error, "Scaleset"]:
scalesets = cls.search(query={"scaleset_id": [scaleset_id]})
if not scalesets:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["unable to find scaleset"]
)
if len(scalesets) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["error identifying scaleset"]
)
scaleset = scalesets[0]
return scaleset
@classmethod
def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]:
return cls.search(query={"client_object_id": [object_id]})
def set_failed(self, error: Error) -> None:
if self.error is not None:
return
self.error = error
self.state = ScalesetState.creation_failed
self.save()
send_event(
EventScalesetFailed(
scaleset_id=self.scaleset_id, pool_name=self.pool_name, error=self.error
)
)
def init(self) -> None:
from .pools import Pool
logging.info("scaleset init: %s", self.scaleset_id)
ScalesetShrinkQueue(self.scaleset_id).create()
# Handle the race condition between a pool being deleted and a
# scaleset being added to the pool.
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.set_failed(pool)
return
if pool.state == PoolState.init:
logging.info(
"scaleset waiting for pool: %s - %s", self.pool_name, self.scaleset_id
)
elif pool.state == PoolState.running:
image_os = get_os(self.region, self.image)
if isinstance(image_os, Error):
self.set_failed(image_os)
return
elif image_os != pool.os:
error = Error(
code=ErrorCode.INVALID_REQUEST,
errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
)
self.set_failed(error)
return
else:
self.state = ScalesetState.setup
else:
self.state = ScalesetState.setup
self.save()
def setup(self) -> None:
from .pools import Pool
        # TODO: How do we pass in SSH configs for Windows? Previously,
        # this was done as part of the generated per-task setup script.
logging.info("scaleset setup: %s", self.scaleset_id)
network = Network(self.region)
network_id = network.get_id()
if not network_id:
logging.info("creating network: %s", self.region)
result = network.create()
if isinstance(result, Error):
self.set_failed(result)
return
self.save()
return
if self.auth is None:
error = Error(
code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
)
self.set_failed(error)
return
vmss = get_vmss(self.scaleset_id)
if vmss is None:
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.set_failed(pool)
return
logging.info("creating scaleset: %s", self.scaleset_id)
extensions = fuzz_extensions(pool, self)
result = create_vmss(
self.region,
self.scaleset_id,
self.vm_sku,
self.size,
self.image,
network_id,
self.spot_instances,
extensions,
self.auth.password,
self.auth.public_key,
self.tags,
)
if isinstance(result, Error):
self.set_failed(result)
return
else:
logging.info("creating scaleset: %s", self.scaleset_id)
elif vmss.provisioning_state == "Creating":
logging.info("Waiting on scaleset creation: %s", self.scaleset_id)
self.try_set_identity(vmss)
else:
logging.info("scaleset running: %s", self.scaleset_id)
identity_result = self.try_set_identity(vmss)
if identity_result:
self.set_failed(identity_result)
return
else:
self.state = ScalesetState.running
self.save()
def try_set_identity(self, vmss: Any) -> Optional[Error]:
def get_error() -> Error:
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=[
"The scaleset is expected to have exactly 1 user assigned identity"
],
)
if self.client_object_id:
return None
if (
vmss.identity
and vmss.identity.user_assigned_identities
and (len(vmss.identity.user_assigned_identities) != 1)
):
return get_error()
        user_assigned_identities = list(
            vmss.identity.user_assigned_identities.values()
        )
        if user_assigned_identities[0].principal_id:
            self.client_object_id = user_assigned_identities[0].principal_id
return None
else:
return get_error()
    # the return value indicates whether this call modified the scaleset in Azure
def cleanup_nodes(self) -> bool:
if self.state == ScalesetState.halt:
logging.info("halting scaleset: %s", self.scaleset_id)
self.halt()
return True
Node.reimage_long_lived_nodes(self.scaleset_id)
to_reimage = []
to_delete = []
# ground truth of existing nodes
azure_nodes = list_instance_ids(self.scaleset_id)
nodes = Node.search_states(scaleset_id=self.scaleset_id)
        # Nodes in the table but missing from the scaleset indicate an unknown failure
for node in nodes:
if node.machine_id not in azure_nodes:
logging.info(
"no longer in scaleset: %s:%s", self.scaleset_id, node.machine_id
)
node.delete()
existing_nodes = [x for x in nodes if x.machine_id in azure_nodes]
nodes_to_reset = [
x for x in existing_nodes if x.state in NodeState.ready_for_reset()
]
for node in nodes_to_reset:
if node.delete_requested:
to_delete.append(node)
else:
if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
node.set_halt()
to_delete.append(node)
elif not node.reimage_queued:
# only add nodes that are not already set to reschedule
to_reimage.append(node)
dead_nodes = Node.get_dead_nodes(self.scaleset_id, NODE_EXPIRATION_TIME)
for node in dead_nodes:
node.set_halt()
to_reimage.append(node)
# Perform operations until they fail due to scaleset getting locked
try:
if to_delete:
logging.info(
"deleting nodes: %s - count: %d", self.scaleset_id, len(to_delete)
)
self.delete_nodes(to_delete)
for node in to_delete:
node.set_halt()
if to_reimage:
self.reimage_nodes(to_reimage)
except UnableToUpdate:
logging.info("scaleset update already in progress: %s", self.scaleset_id)
return bool(to_reimage) or bool(to_delete)
def _resize_equal(self) -> None:
# NOTE: this is the only place we reset to the 'running' state.
# This ensures that our idea of scaleset size agrees with Azure
node_count = len(Node.search_states(scaleset_id=self.scaleset_id))
if node_count == self.size:
logging.info("resize finished: %s", self.scaleset_id)
self.state = ScalesetState.running
self.save()
return
else:
logging.info(
"resize is finished, waiting for nodes to check in: "
"%s (%d of %d nodes checked in)",
self.scaleset_id,
node_count,
self.size,
)
return
def _resize_grow(self) -> None:
try:
resize_vmss(self.scaleset_id, self.size)
except UnableToUpdate:
logging.info("scaleset is mid-operation already")
return
def _resize_shrink(self, to_remove: int) -> None:
queue = ScalesetShrinkQueue(self.scaleset_id)
for _ in range(to_remove):
queue.add_entry()
def resize(self) -> None:
# no longer needing to resize
if self.state != ScalesetState.resize:
return
logging.info("scaleset resize: %s - %s", self.scaleset_id, self.size)
# reset the node delete queue
ScalesetShrinkQueue(self.scaleset_id).clear()
# just in case, always ensure size is within max capacity
self.size = min(self.size, self.max_size())
# Treat Azure knowledge of the size of the scaleset as "ground truth"
size = get_vmss_size(self.scaleset_id)
if size is None:
logging.info("scaleset is unavailable: %s", self.scaleset_id)
return
if size == self.size:
self._resize_equal()
elif self.size > size:
self._resize_grow()
else:
self._resize_shrink(size - self.size)
def delete_nodes(self, nodes: List[Node]) -> None:
if not nodes:
logging.debug("no nodes to delete")
return
if self.state == ScalesetState.halt:
logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
return
machine_ids = []
for node in nodes:
if node.debug_keep_node:
logging.warning(
"delete manually overridden %s:%s",
self.scaleset_id,
node.machine_id,
)
else:
machine_ids.append(node.machine_id)
logging.info("deleting %s:%s", self.scaleset_id, machine_ids)
delete_vmss_nodes(self.scaleset_id, machine_ids)
def reimage_nodes(self, nodes: List[Node]) -> None:
if not nodes:
logging.debug("no nodes to reimage")
return
if self.state == ScalesetState.shutdown:
self.delete_nodes(nodes)
return
if self.state == ScalesetState.halt:
logging.debug("scaleset delete will delete node: %s", self.scaleset_id)
return
machine_ids = []
for node in nodes:
if node.debug_keep_node:
logging.warning(
"reimage manually overridden %s:%s",
self.scaleset_id,
node.machine_id,
)
else:
machine_ids.append(node.machine_id)
result = reimage_vmss_nodes(self.scaleset_id, machine_ids)
if isinstance(result, Error):
raise Exception(
"unable to reimage nodes: %s:%s - %s"
% (self.scaleset_id, machine_ids, result)
)
for node in nodes:
node.reimage_queued = True
node.save()
def set_shutdown(self, now: bool) -> None:
if self.state in [ScalesetState.halt, ScalesetState.shutdown]:
return
if now:
self.state = ScalesetState.halt
else:
self.state = ScalesetState.shutdown
self.save()
def shutdown(self) -> None:
size = get_vmss_size(self.scaleset_id)
logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
nodes = Node.search_states(scaleset_id=self.scaleset_id)
for node in nodes:
node.set_shutdown()
if size is None or size == 0:
self.halt()
def halt(self) -> None:
ScalesetShrinkQueue(self.scaleset_id).delete()
for node in Node.search_states(scaleset_id=self.scaleset_id):
logging.info("deleting node %s:%s", self.scaleset_id, node.machine_id)
node.delete()
logging.info("scaleset delete starting: %s", self.scaleset_id)
if delete_vmss(self.scaleset_id):
logging.info("scaleset deleted: %s", self.scaleset_id)
self.delete()
else:
self.save()
@classmethod
def scaleset_max_size(cls, image: str) -> int:
# https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/
# virtual-machine-scale-sets-placement-groups#checklist-for-using-large-scale-sets
if image.startswith("/"):
return 600
else:
return 1000
def max_size(self) -> int:
return Scaleset.scaleset_max_size(self.image)
@classmethod
def search_states(
cls, *, states: Optional[List[ScalesetState]] = None
) -> List["Scaleset"]:
query: QueryFilter = {}
if states:
query["state"] = states
return cls.search(query=query)
def update_nodes(self) -> None:
        # must be at least in the 'setup' state before checking for the list of VMs
if self.state == ScalesetState.init:
return
nodes = Node.search_states(scaleset_id=self.scaleset_id)
azure_nodes = list_instance_ids(self.scaleset_id)
self.nodes = []
for (machine_id, instance_id) in azure_nodes.items():
node_state: Optional[ScalesetNodeState] = None
for node in nodes:
if node.machine_id == machine_id:
node_state = ScalesetNodeState(
machine_id=machine_id,
instance_id=instance_id,
state=node.state,
)
break
if not node_state:
node_state = ScalesetNodeState(
machine_id=machine_id,
instance_id=instance_id,
)
self.nodes.append(node_state)
def update_configs(self) -> None:
from .pools import Pool
        if not self.needs_config_update:
            logging.debug("config update not needed: %s", self.scaleset_id)
            return
logging.info("updating scaleset configs: %s", self.scaleset_id)
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
logging.error(
"unable to find pool during config update: %s - %s",
self.scaleset_id,
pool,
)
self.set_failed(pool)
return
extensions = fuzz_extensions(pool, self)
try:
update_extensions(self.scaleset_id, extensions)
self.needs_config_update = False
self.save()
except UnableToUpdate:
logging.debug(
"unable to update configs, update already in progress: %s",
self.scaleset_id,
)
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("pool_name", "scaleset_id")
def delete(self) -> None:
super().delete()
send_event(
EventScalesetDeleted(scaleset_id=self.scaleset_id, pool_name=self.pool_name)
)
class ShrinkEntry(BaseModel):
shrink_id: UUID = Field(default_factory=uuid4)
class ScalesetShrinkQueue:
def __init__(self, scaleset_id: UUID):
self.scaleset_id = scaleset_id
def queue_name(self) -> str:
return "to-shrink-%s" % self.scaleset_id.hex
def clear(self) -> None:
clear_queue(self.queue_name(), StorageType.config)
def create(self) -> None:
create_queue(self.queue_name(), StorageType.config)
def delete(self) -> None:
delete_queue(self.queue_name(), StorageType.config)
def add_entry(self) -> None:
queue_object(self.queue_name(), ShrinkEntry(), StorageType.config)
def should_shrink(self) -> bool:
return remove_first_message(self.queue_name(), StorageType.config)
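ScalesetShrinkQueue implements shrink-by-N as N queue entries: _resize_shrink enqueues one entry per node to remove, and each node that sees should_shrink() return True (that is, successfully pops an entry) is halted and deleted, so exactly N nodes leave the scaleset. Below is a hedged, in-memory stand-in for that protocol; the real implementation uses an Azure Storage queue via queue_object and remove_first_message.

from collections import deque


class InMemoryShrinkQueue:
    """In-memory stand-in for ScalesetShrinkQueue; illustrative only."""

    def __init__(self) -> None:
        self._entries: deque = deque()

    def add_entry(self) -> None:
        # one entry == permission to remove one node
        self._entries.append(object())

    def should_shrink(self) -> bool:
        # mirrors remove_first_message(): True only if an entry was consumed
        if self._entries:
            self._entries.popleft()
            return True
        return False


queue = InMemoryShrinkQueue()
for _ in range(3):  # request that three nodes be removed
    queue.add_entry()
print([queue.should_shrink() for _ in range(5)])  # [True, True, True, False, False]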

View File

@@ -23,8 +23,8 @@ from ..onefuzzlib.azure.storage import StorageType
 from ..onefuzzlib.azure.vmss import list_available_skus
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Pool
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.pools import Pool
 def set_config(pool: Pool) -> Pool:

View File

@@ -13,10 +13,10 @@ from onefuzztypes.responses import BoolResult, ProxyGetResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Scaleset
 from ..onefuzzlib.proxy import Proxy
 from ..onefuzzlib.proxy_forward import ProxyForward
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.scalesets import Scaleset
 def get_result(proxy_forward: ProxyForward, proxy: Optional[Proxy]) -> ProxyGetResult:

View File

@@ -12,7 +12,7 @@ from onefuzztypes.models import NodeHeartbeatEntry
 from pydantic import ValidationError
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Node
+from ..onefuzzlib.workers.nodes import Node
 def main(msg: func.QueueMessage, dashboard: func.Out[str]) -> None:

View File

@@ -18,8 +18,9 @@ from ..onefuzzlib.azure.creds import get_base_region, get_regions
 from ..onefuzzlib.azure.vmss import list_available_skus
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Pool, Scaleset
 from ..onefuzzlib.request import not_ok, ok, parse_request
+from ..onefuzzlib.workers.pools import Pool
+from ..onefuzzlib.workers.scalesets import Scaleset
 def get(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -12,12 +12,12 @@ from onefuzztypes.responses import BoolResult
 from ..onefuzzlib.endpoint_authorization import call_if_user
 from ..onefuzzlib.events import get_events
 from ..onefuzzlib.jobs import Job
-from ..onefuzzlib.pools import NodeTasks
 from ..onefuzzlib.request import not_ok, ok, parse_request
 from ..onefuzzlib.task_event import TaskEvent
 from ..onefuzzlib.tasks.config import TaskConfigError, check_config
 from ..onefuzzlib.tasks.main import Task
 from ..onefuzzlib.user_credentials import parse_jwt_token
+from ..onefuzzlib.workers.nodes import NodeTasks
 def post(req: func.HttpRequest) -> func.HttpResponse:

View File

@@ -9,9 +9,9 @@ import azure.functions as func
 from onefuzztypes.enums import VmState
 from ..onefuzzlib.events import get_events
-from ..onefuzzlib.pools import Scaleset
 from ..onefuzzlib.proxy import Proxy
 from ..onefuzzlib.webhooks import WebhookMessageLog
+from ..onefuzzlib.workers.scalesets import Scaleset
 def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None:  # noqa: F841

View File

@@ -11,7 +11,9 @@ from onefuzztypes.enums import NodeState, PoolState
 from ..onefuzzlib.autoscale import autoscale_pool
 from ..onefuzzlib.events import get_events
 from ..onefuzzlib.orm import process_state_updates
-from ..onefuzzlib.pools import Node, Pool, Scaleset
+from ..onefuzzlib.workers.nodes import Node
+from ..onefuzzlib.workers.pools import Pool
+from ..onefuzzlib.workers.scalesets import Scaleset
 def process_scaleset(scaleset: Scaleset) -> None:

View File

@@ -12,8 +12,8 @@ from onefuzztypes.models import TaskConfig, TaskContainers, TaskDetails, TaskPool
 from onefuzztypes.primitives import Container, PoolName
 from __app__.onefuzzlib.autoscale import autoscale_pool, get_vm_count
-from __app__.onefuzzlib.pools import Pool
 from __app__.onefuzzlib.tasks.main import Task
+from __app__.onefuzzlib.workers.pools import Pool
 class TestAutoscale(unittest.TestCase):