Remove update_event as a single event loop for the system (#160)

bmc-msft authored on 2020-10-16 21:42:35 -04:00, committed by GitHub
parent 9fa25803ab
commit 75f29b9f2e
24 changed files with 418 additions and 324 deletions

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import math
from typing import List
from onefuzztypes.enums import NodeState, ScalesetState
from onefuzztypes.models import AutoScaleConfig, TaskPool
from .pools import Node, Pool, Scaleset
from .tasks.main import Task
def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
logging.info("Scaling up")
autoscale_config = pool.autoscale
if not isinstance(autoscale_config, AutoScaleConfig):
return
for scaleset in scalesets:
if scaleset.state in [ScalesetState.running, ScalesetState.resize]:
max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
logging.info(
"scaleset:%s size:%d max_size:%d"
% (scaleset.scaleset_id, scaleset.size, max_size)
)
if scaleset.size < max_size:
current_size = scaleset.size
if nodes_needed <= max_size - current_size:
scaleset.size = current_size + nodes_needed
nodes_needed = 0
else:
scaleset.size = max_size
nodes_needed = nodes_needed - (max_size - current_size)
scaleset.state = ScalesetState.resize
scaleset.save()
else:
continue
if nodes_needed == 0:
return
for _ in range(
math.ceil(
nodes_needed
/ min(
Scaleset.scaleset_max_size(autoscale_config.image),
autoscale_config.scaleset_size,
)
)
):
logging.info("Creating Scaleset for Pool %s" % (pool.name))
max_nodes_scaleset = min(
Scaleset.scaleset_max_size(autoscale_config.image),
autoscale_config.scaleset_size,
nodes_needed,
)
if not autoscale_config.region:
raise Exception("Region is missing")
scaleset = Scaleset.create(
pool_name=pool.name,
vm_sku=autoscale_config.vm_sku,
image=autoscale_config.image,
region=autoscale_config.region,
size=max_nodes_scaleset,
spot_instances=autoscale_config.spot_instances,
tags={"pool": pool.name},
)
scaleset.save()
nodes_needed -= max_nodes_scaleset
def scale_down(scalesets: List[Scaleset], nodes_to_remove: int) -> None:
logging.info("Scaling down")
for scaleset in scalesets:
nodes = Node.search_states(
scaleset_id=scaleset.scaleset_id, states=[NodeState.free]
)
if nodes and nodes_to_remove > 0:
max_nodes_remove = min(len(nodes), nodes_to_remove)
if max_nodes_remove >= scaleset.size and len(nodes) == scaleset.size:
scaleset.state = ScalesetState.shutdown
nodes_to_remove = nodes_to_remove - scaleset.size
scaleset.save()
for node in nodes:
node.set_shutdown()
continue
scaleset.size = scaleset.size - max_nodes_remove
nodes_to_remove = nodes_to_remove - max_nodes_remove
scaleset.state = ScalesetState.resize
scaleset.save()
def get_vm_count(tasks: List[Task]) -> int:
    count = 0
    for task in tasks:
        task_pool = task.get_pool()
        if (
            not task_pool
            or not isinstance(task_pool, Pool)
            or not isinstance(task.config.pool, TaskPool)
        ):
            continue
        count += task.config.pool.count
    return count

def autoscale_pool(pool: Pool) -> None:
logging.info("autoscale: %s" % (pool.autoscale))
if not pool.autoscale:
return
# get all the tasks (count not stopped) for the pool
tasks = Task.get_tasks_by_pool_name(pool.name)
logging.info("Pool: %s, #Tasks %d" % (pool.name, len(tasks)))
num_of_tasks = get_vm_count(tasks)
nodes_needed = max(num_of_tasks, pool.autoscale.min_size)
if pool.autoscale.max_size:
nodes_needed = min(nodes_needed, pool.autoscale.max_size)
# do scaleset logic match with pool
# get all the scalesets for the pool
scalesets = Scaleset.search_by_pool(pool.name)
pool_resize = False
for scaleset in scalesets:
if scaleset.state in ScalesetState.modifying():
pool_resize = True
break
nodes_needed = nodes_needed - scaleset.size
if pool_resize:
return
logging.info("Pool: %s, #Nodes Needed: %d" % (pool.name, nodes_needed))
if nodes_needed > 0:
# resizing scaleset or creating new scaleset.
scale_up(pool, scalesets, nodes_needed)
elif nodes_needed < 0:
scale_down(scalesets, abs(nodes_needed))
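A minimal sketch of how autoscale_pool is meant to be driven under the new model: rather than funneling everything through a single queued event loop, a periodic timer can walk the pools and let each one resize independently. The entry point name and the Pool.search_states helper below are assumptions for illustration, not part of this diff.

from onefuzztypes.enums import PoolState

from .autoscale import autoscale_pool
from .pools import Pool


def autoscale_all_pools() -> None:
    # assumed periodic entry point (e.g. a timer-triggered function); each
    # pool is examined and resized on its own, so a slow or locked pool no
    # longer stalls state processing for the others
    for pool in Pool.search_states(states=PoolState.available()):
        autoscale_pool(pool)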

View File

@@ -73,9 +73,6 @@ class Job(BASE_JOB, ORMMixin):
        self.state = JobState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:

View File

@@ -36,6 +36,7 @@ from onefuzztypes.enums import (
from onefuzztypes.models import Error
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from typing_extensions import Protocol
from .azure.table import get_client
from .dashboard import add_event
@@ -66,6 +67,36 @@ KEY = Union[int, str, UUID, Enum]
HOURS = 60 * 60

class HasState(Protocol):
    # TODO: this should be bound tighter than Any
    # In the end, we want this to be an Enum.  Specifically, one of
    # the JobState, TaskState, etc enums.
    state: Any


def process_state_update(obj: HasState) -> None:
    """
    process a single state update, if the obj
    implements a function for that state
    """
    func = getattr(obj, obj.state.name, None)
    if func is None:
        return
    func()


def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
    """ process through the state machine for an object """
    for _ in range(max_updates):
        state = obj.state
        process_state_update(obj)
        new_state = obj.state
        if new_state == state:
            break


def resolve(key: KEY) -> str:
    if isinstance(key, str):
        return key
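A toy illustration of the dispatch above: process_state_update looks up a method named after the object's current state and calls it, and process_state_updates repeats that until the state stops changing (or max_updates is reached). The Demo class is invented for this example only.

from enum import Enum


class DemoState(Enum):
    init = "init"
    running = "running"
    stopped = "stopped"


class Demo:
    # satisfies the HasState protocol: all it needs is a `state` attribute
    def __init__(self) -> None:
        self.state = DemoState.init

    def init(self) -> None:
        self.state = DemoState.running

    def running(self) -> None:
        self.state = DemoState.stopped

    # no method named "stopped", so the state machine stops advancing there


demo = Demo()
process_state_updates(demo)  # init -> running -> stopped, then breaks
assert demo.state == DemoState.stopped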

View File

@@ -69,6 +69,10 @@ from .orm import MappingIntStrAny, ORMMixin, QueryFilter
class Node(BASE_NODE, ORMMixin):
    # should only be set by Scaleset.reimage_nodes
    # should only be unset during agent_registration POST
    reimage_queued: bool = Field(default=False)

    @classmethod
    def search_states(
        cls,
@@ -108,6 +112,21 @@ class Node(BASE_NODE, ORMMixin):
        version_query = "not (version eq '%s')" % __version__
        return cls.search(query=query, raw_unchecked_filter=version_query)

    @classmethod
    def mark_outdated_nodes(cls) -> None:
        outdated = cls.search_outdated()
        for node in outdated:
            logging.info(
                "node is outdated: %s - node_version:%s api_version:%s",
                node.machine_id,
                node.version,
                __version__,
            )
            if node.version == "1.0.0":
                node.to_reimage(done=True)
            else:
                node.to_reimage()

    @classmethod
    def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
        nodes = cls.search(query={"machine_id": [machine_id]})
@@ -195,9 +214,24 @@ class Node(BASE_NODE, ORMMixin):
            self.stop()
            return False

        if self.delete_requested or self.reimage_requested:
        if self.state in NodeState.ready_for_reset():
            logging.info(
                "can_schedule should be recycled. machine_id:%s", self.machine_id
                "can_schedule node is set for reset. machine_id:%s", self.machine_id
            )
            return False

        if self.delete_requested:
            logging.info(
                "can_schedule is set to be deleted. machine_id:%s",
                self.machine_id,
            )
            self.stop()
            return False

        if self.reimage_requested:
            logging.info(
                "can_schedule is set to be reimaged. machine_id:%s",
                self.machine_id,
            )
            self.stop()
            return False
@@ -682,25 +716,11 @@ class Scaleset(BASE_SCALESET, ORMMixin):
        to_reimage = []
        to_delete = []

        outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
        for node in outdated:
            logging.info(
                "node is outdated: %s - node_version:%s api_version:%s",
                node.machine_id,
                node.version,
                __version__,
            )
            if node.version == "1.0.0":
                node.state = NodeState.done
                to_reimage.append(node)
            else:
                node.to_reimage()

        nodes = Node.search_states(
            scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
        )
        if not outdated and not nodes:
        if not nodes:
            logging.info("no nodes need updating: %s", self.scaleset_id)
            return False
@@ -719,7 +739,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
            if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
                node.set_halt()
                to_delete.append(node)
            else:
            elif not node.reimage_queued:
                # only add nodes that are not already set to reschedule
                to_reimage.append(node)

        # Perform operations until they fail due to scaleset getting locked
@@ -833,6 +854,9 @@ class Scaleset(BASE_SCALESET, ORMMixin):
                "unable to reimage nodes: %s:%s - %s"
                % (self.scaleset_id, machine_ids, result)
            )

        for node in nodes:
            node.reimage_queued = True
            node.save()

    def shutdown(self) -> None:
        size = get_vmss_size(self.scaleset_id)
@@ -855,7 +879,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
            self.save()
        else:
            logging.info("scaleset deleted: %s", self.scaleset_id)
            self.state = ScalesetState.halt
            self.delete()

    @classmethod
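A sketch of where the node changes above fit in the timer-driven flow (the entry point is an assumption, not shown in this diff): mark_outdated_nodes flags nodes running an old agent version, the per-scaleset cleanup pass later reimages them, and the new reimage_queued flag keeps a node from being queued for reimage twice while the VMSS operation is in flight.

from .pools import Node


def process_outdated_nodes() -> None:
    # assumed periodic caller; the actual VMSS work happens later in
    # Scaleset.cleanup_nodes / Scaleset.reimage_nodes
    Node.mark_outdated_nodes()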

View File

@@ -27,7 +27,7 @@ from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.vm import VM
from .extension import proxy_manager_extensions
from .orm import HOURS, MappingIntStrAny, ORMMixin, QueryFilter
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
PROXY_SKU = "Standard_B2s"
@@ -210,9 +210,6 @@ class Proxy(ORMMixin):
            account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
        )

    def queue_stop(self, count: int) -> None:
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
        query: QueryFilter = {}

View File

@@ -4,6 +4,7 @@
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from azure.mgmt.compute.models import VirtualMachine
@@ -18,7 +19,7 @@ from .azure.creds import get_base_region, get_func_storage
from .azure.ip import get_public_ip
from .azure.vm import VM
from .extension import repro_extensions
from .orm import HOURS, ORMMixin, QueryFilter
from .orm import ORMMixin, QueryFilter
from .reports import get_report
from .tasks.main import Task
@@ -205,9 +206,6 @@ class Repro(BASE_REPRO, ORMMixin):
        logging.info("saved repro script")
        return None

    def queue_stop(self, count: int) -> None:
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
        query: QueryFilter = {}
@@ -228,10 +226,18 @@ class Repro(BASE_REPRO, ORMMixin):
            return task

        vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
        if vm.end_time is None:
            vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
        vm.save()
        vm.queue_stop(config.duration)
        return vm

    @classmethod
    def search_expired(cls) -> List["Repro"]:
        # unlike jobs/tasks, the entry is deleted from the backing table upon stop
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(raw_unchecked_filter=time_filter)

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        return ("vm_id", None)

View File

@@ -118,9 +118,6 @@ class Task(BASE_TASK, ORMMixin):
        self.state = TaskState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
@@ -165,7 +162,7 @@ class Task(BASE_TASK, ORMMixin):
            task_pool = task.get_pool()
            if not task_pool:
                continue
            if pool_name == task_pool.name and task.state in TaskState.available():
            if pool_name == task_pool.name:
                pool_tasks.append(task)

        return pool_tasks

View File

@@ -7,17 +7,34 @@ import logging
from typing import Dict, List
from uuid import UUID
from onefuzztypes.enums import OS, TaskState
from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task
HOURS = 60 * 60
def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
    if pool.state not in PoolState.available():
        logging.info(
            "pool not available for work: %s state: %s", pool.name, pool.state.name
        )
        return False

    for _ in range(count):
        if not pool.schedule_workset(workset):
            logging.error(
                "unable to schedule workset. pool:%s workset:%s", pool.name, workset
            )
            return False
    return True

def schedule_tasks() -> None:
    to_schedule: Dict[UUID, List[Task]] = {}
@ -82,7 +99,7 @@ def schedule_tasks() -> None:
            )

            # For now, only offer singleton work sets.
            work_set = WorkSet(
            workset = WorkSet(
                reboot=reboot,
                script=(setup_script is not None),
                setup_url=setup_url,
@ -94,7 +111,6 @@ def schedule_tasks() -> None:
logging.info("unable to find pool for task: %s", task.task_id)
continue
for _ in range(count):
pool.schedule_workset(work_set)
task.state = TaskState.scheduled
task.save()
if schedule_workset(workset, pool, count):
task.state = TaskState.scheduled
task.save()