mirror of https://github.com/microsoft/onefuzz.git (synced 2025-06-18 12:48:07 +00:00)
Remove update_event as a single event loop for the system (#160)
148  src/api-service/__app__/onefuzzlib/autoscale.py  (new file)
@@ -0,0 +1,148 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import logging
import math
from typing import List

from onefuzztypes.enums import NodeState, ScalesetState
from onefuzztypes.models import AutoScaleConfig, TaskPool

from .pools import Node, Pool, Scaleset
from .tasks.main import Task


def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
    logging.info("Scaling up")
    autoscale_config = pool.autoscale
    if not isinstance(autoscale_config, AutoScaleConfig):
        return

    for scaleset in scalesets:
        if scaleset.state in [ScalesetState.running, ScalesetState.resize]:

            max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
            logging.info(
                "scaleset:%s size:%d max_size:%d"
                % (scaleset.scaleset_id, scaleset.size, max_size)
            )
            if scaleset.size < max_size:
                current_size = scaleset.size
                if nodes_needed <= max_size - current_size:
                    scaleset.size = current_size + nodes_needed
                    nodes_needed = 0
                else:
                    scaleset.size = max_size
                    nodes_needed = nodes_needed - (max_size - current_size)
                scaleset.state = ScalesetState.resize
                scaleset.save()

            else:
                continue

    if nodes_needed == 0:
        return

    for _ in range(
        math.ceil(
            nodes_needed
            / min(
                Scaleset.scaleset_max_size(autoscale_config.image),
                autoscale_config.scaleset_size,
            )
        )
    ):
        logging.info("Creating Scaleset for Pool %s" % (pool.name))
        max_nodes_scaleset = min(
            Scaleset.scaleset_max_size(autoscale_config.image),
            autoscale_config.scaleset_size,
            nodes_needed,
        )

        if not autoscale_config.region:
            raise Exception("Region is missing")

        scaleset = Scaleset.create(
            pool_name=pool.name,
            vm_sku=autoscale_config.vm_sku,
            image=autoscale_config.image,
            region=autoscale_config.region,
            size=max_nodes_scaleset,
            spot_instances=autoscale_config.spot_instances,
            tags={"pool": pool.name},
        )
        scaleset.save()
        nodes_needed -= max_nodes_scaleset


def scale_down(scalesets: List[Scaleset], nodes_to_remove: int) -> None:
    logging.info("Scaling down")
    for scaleset in scalesets:
        nodes = Node.search_states(
            scaleset_id=scaleset.scaleset_id, states=[NodeState.free]
        )
        if nodes and nodes_to_remove > 0:
            max_nodes_remove = min(len(nodes), nodes_to_remove)
            if max_nodes_remove >= scaleset.size and len(nodes) == scaleset.size:
                scaleset.state = ScalesetState.shutdown
                nodes_to_remove = nodes_to_remove - scaleset.size
                scaleset.save()
                for node in nodes:
                    node.set_shutdown()
                continue

            scaleset.size = scaleset.size - max_nodes_remove
            nodes_to_remove = nodes_to_remove - max_nodes_remove
            scaleset.state = ScalesetState.resize
            scaleset.save()


def get_vm_count(tasks: List[Task]) -> int:
    count = 0
    for task in tasks:
        task_pool = task.get_pool()
        if (
            not task_pool
            or not isinstance(task_pool, Pool)
            or not isinstance(task.config.pool, TaskPool)
        ):
            continue
        count += task.config.pool.count
    return count


def autoscale_pool(pool: Pool) -> None:
    logging.info("autoscale: %s" % (pool.autoscale))
    if not pool.autoscale:
        return

    # get all the tasks (count not stopped) for the pool
    tasks = Task.get_tasks_by_pool_name(pool.name)
    logging.info("Pool: %s, #Tasks %d" % (pool.name, len(tasks)))

    num_of_tasks = get_vm_count(tasks)
    nodes_needed = max(num_of_tasks, pool.autoscale.min_size)
    if pool.autoscale.max_size:
        nodes_needed = min(nodes_needed, pool.autoscale.max_size)

    # do scaleset logic match with pool
    # get all the scalesets for the pool
    scalesets = Scaleset.search_by_pool(pool.name)
    pool_resize = False
    for scaleset in scalesets:
        if scaleset.state in ScalesetState.modifying():
            pool_resize = True
            break
        nodes_needed = nodes_needed - scaleset.size

    if pool_resize:
        return

    logging.info("Pool: %s, #Nodes Needed: %d" % (pool.name, nodes_needed))
    if nodes_needed > 0:
        # resizing scaleset or creating new scaleset.
        scale_up(pool, scalesets, nodes_needed)
    elif nodes_needed < 0:
        scale_down(scalesets, abs(nodes_needed))
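
(For context, a sketch of how autoscale_pool could be driven per pool now that there is no single update_event loop. This is illustrative only and not part of this diff; the timer entry point and Pool.search_states are assumptions, mirroring the search_states helpers used elsewhere in this change.)

# illustrative only -- hypothetical per-pool timer, not part of this commit
from onefuzztypes.enums import PoolState

from ..onefuzzlib.autoscale import autoscale_pool
from ..onefuzzlib.pools import Pool


def main() -> None:
    # assumption: Pool.search_states mirrors Node.search_states from this diff
    for pool in Pool.search_states(states=PoolState.available()):
        if pool.autoscale:
            autoscale_pool(pool)
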
@@ -73,9 +73,6 @@ class Job(BASE_JOB, ORMMixin):
        self.state = JobState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    def on_start(self) -> None:
        # try to keep this effectively idempotent
        if self.end_time is None:
@@ -36,6 +36,7 @@ from onefuzztypes.enums import (
from onefuzztypes.models import Error
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from typing_extensions import Protocol

from .azure.table import get_client
from .dashboard import add_event
@@ -66,6 +67,36 @@ KEY = Union[int, str, UUID, Enum]
HOURS = 60 * 60


class HasState(Protocol):
    # TODO: this should be bound tighter than Any
    # In the end, we want this to be an Enum. Specifically, one of
    # the JobState,TaskState,etc enums.
    state: Any


def process_state_update(obj: HasState) -> None:
    """
    process a single state update, if the obj
    implements a function for that state
    """

    func = getattr(obj, obj.state.name, None)
    if func is None:
        return
    func()


def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
    """ process through the state machine for an object """

    for _ in range(max_updates):
        state = obj.state
        process_state_update(obj)
        new_state = obj.state
        if new_state == state:
            break


def resolve(key: KEY) -> str:
    if isinstance(key, str):
        return key
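
(Illustrative only: a minimal example of the state-machine dispatch added above. process_state_update looks up a method named after the current state and calls it; process_state_updates repeats this until the state stops changing. DemoState and Demo below are hypothetical, not part of this diff.)

from enum import Enum


class DemoState(Enum):
    init = "init"
    running = "running"
    stopped = "stopped"


class Demo:
    state: DemoState = DemoState.init

    def init(self) -> None:
        self.state = DemoState.running

    def running(self) -> None:
        self.state = DemoState.stopped

    # no method named "stopped", so the state machine halts there


demo = Demo()
process_state_updates(demo)  # init -> running -> stopped
assert demo.state is DemoState.stopped
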
@@ -69,6 +69,10 @@ from .orm import MappingIntStrAny, ORMMixin, QueryFilter


class Node(BASE_NODE, ORMMixin):
    # should only be set by Scaleset.reimage_nodes
    # should only be unset during agent_registration POST
    reimage_queued: bool = Field(default=False)

    @classmethod
    def search_states(
        cls,
@@ -108,6 +112,21 @@ class Node(BASE_NODE, ORMMixin):
        version_query = "not (version eq '%s')" % __version__
        return cls.search(query=query, raw_unchecked_filter=version_query)

    @classmethod
    def mark_outdated_nodes(cls) -> None:
        outdated = cls.search_outdated()
        for node in outdated:
            logging.info(
                "node is outdated: %s - node_version:%s api_version:%s",
                node.machine_id,
                node.version,
                __version__,
            )
            if node.version == "1.0.0":
                node.to_reimage(done=True)
            else:
                node.to_reimage()

    @classmethod
    def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
        nodes = cls.search(query={"machine_id": [machine_id]})
@@ -195,9 +214,24 @@ class Node(BASE_NODE, ORMMixin):
            self.stop()
            return False

        if self.delete_requested or self.reimage_requested:
        if self.state in NodeState.ready_for_reset():
            logging.info(
                "can_schedule should be recycled. machine_id:%s", self.machine_id
                "can_schedule node is set for reset. machine_id:%s", self.machine_id
            )
            return False

        if self.delete_requested:
            logging.info(
                "can_schedule is set to be deleted. machine_id:%s",
                self.machine_id,
            )
            self.stop()
            return False

        if self.reimage_requested:
            logging.info(
                "can_schedule is set to be reimaged. machine_id:%s",
                self.machine_id,
            )
            self.stop()
            return False
@@ -682,25 +716,11 @@ class Scaleset(BASE_SCALESET, ORMMixin):
        to_reimage = []
        to_delete = []

        outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
        for node in outdated:
            logging.info(
                "node is outdated: %s - node_version:%s api_version:%s",
                node.machine_id,
                node.version,
                __version__,
            )
            if node.version == "1.0.0":
                node.state = NodeState.done
                to_reimage.append(node)
            else:
                node.to_reimage()

        nodes = Node.search_states(
            scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
        )

        if not outdated and not nodes:
        if not nodes:
            logging.info("no nodes need updating: %s", self.scaleset_id)
            return False
@@ -719,7 +739,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
            if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
                node.set_halt()
                to_delete.append(node)
            else:
            elif not node.reimage_queued:
                # only add nodes that are not already set to reschedule
                to_reimage.append(node)

        # Perform operations until they fail due to scaleset getting locked
@@ -833,6 +854,9 @@ class Scaleset(BASE_SCALESET, ORMMixin):
                "unable to reimage nodes: %s:%s - %s"
                % (self.scaleset_id, machine_ids, result)
            )
        for node in nodes:
            node.reimage_queued = True
            node.save()

    def shutdown(self) -> None:
        size = get_vmss_size(self.scaleset_id)
@@ -855,7 +879,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
            self.save()
        else:
            logging.info("scaleset deleted: %s", self.scaleset_id)
            self.state = ScalesetState.halt
            self.delete()

    @classmethod
@@ -27,7 +27,7 @@ from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.vm import VM
from .extension import proxy_manager_extensions
from .orm import HOURS, MappingIntStrAny, ORMMixin, QueryFilter
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .proxy_forward import ProxyForward

PROXY_SKU = "Standard_B2s"
@@ -210,9 +210,6 @@ class Proxy(ORMMixin):
            account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
        )

    def queue_stop(self, count: int) -> None:
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
        query: QueryFilter = {}
@@ -4,6 +4,7 @@
# Licensed under the MIT License.

import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union

from azure.mgmt.compute.models import VirtualMachine
@@ -18,7 +19,7 @@ from .azure.creds import get_base_region, get_func_storage
from .azure.ip import get_public_ip
from .azure.vm import VM
from .extension import repro_extensions
from .orm import HOURS, ORMMixin, QueryFilter
from .orm import ORMMixin, QueryFilter
from .reports import get_report
from .tasks.main import Task
@@ -205,9 +206,6 @@ class Repro(BASE_REPRO, ORMMixin):
        logging.info("saved repro script")
        return None

    def queue_stop(self, count: int) -> None:
        self.queue(method=self.stopping, visibility_timeout=count * HOURS)

    @classmethod
    def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
        query: QueryFilter = {}
@@ -228,10 +226,18 @@ class Repro(BASE_REPRO, ORMMixin):
            return task

        vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
        if vm.end_time is None:
            vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
        vm.save()
        vm.queue_stop(config.duration)

        return vm

    @classmethod
    def search_expired(cls) -> List["Repro"]:
        # unlike jobs/tasks, the entry is deleted from the backing table upon stop
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(raw_unchecked_filter=time_filter)

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        return ("vm_id", None)
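
(A hedged sketch of how the new end_time field and search_expired classmethod might be used by a periodic cleanup pass; not part of this diff. The stopping() call is an assumption based on the VmState-style handlers these models use.)

# hypothetical periodic cleanup, illustrative only
import logging

from ..onefuzzlib.repro import Repro


def stop_expired_repro_vms() -> None:
    for vm in Repro.search_expired():
        logging.info("stopping expired repro vm: %s", vm.vm_id)
        # assumption: stopping() is the VmState.stopping handler on Repro
        vm.stopping()
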
@@ -118,9 +118,6 @@ class Task(BASE_TASK, ORMMixin):
        self.state = TaskState.stopped
        self.save()

    def queue_stop(self) -> None:
        self.queue(method=self.stopping)

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
@@ -165,7 +162,7 @@ class Task(BASE_TASK, ORMMixin):
            task_pool = task.get_pool()
            if not task_pool:
                continue
            if pool_name == task_pool.name and task.state in TaskState.available():
            if pool_name == task_pool.name:
                pool_tasks.append(task)

        return pool_tasks
@@ -7,17 +7,34 @@ import logging
from typing import Dict, List
from uuid import UUID

from onefuzztypes.enums import OS, TaskState
from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit

from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task

HOURS = 60 * 60


def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
    if pool.state not in PoolState.available():
        logging.info(
            "pool not available for work: %s state: %s", pool.name, pool.state.name
        )
        return False

    for _ in range(count):
        if not pool.schedule_workset(workset):
            logging.error(
                "unable to schedule workset. pool:%s workset:%s", pool.name, workset
            )
            return False
    return True


def schedule_tasks() -> None:
    to_schedule: Dict[UUID, List[Task]] = {}
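
(In short, schedule_workset folds the pool-availability check and the per-count scheduling loop into one call that reports success; a simplified, illustrative call pattern is below. A task whose workset cannot be queued stays in TaskState.waiting and is retried on the next scheduling pass.)

# simplified call pattern, illustrative only
if schedule_workset(workset, pool, count):
    task.state = TaskState.scheduled
    task.save()
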
@@ -82,7 +99,7 @@ def schedule_tasks() -> None:
            )

            # For now, only offer singleton work sets.
            work_set = WorkSet(
            workset = WorkSet(
                reboot=reboot,
                script=(setup_script is not None),
                setup_url=setup_url,
@@ -94,7 +111,6 @@ def schedule_tasks() -> None:
                logging.info("unable to find pool for task: %s", task.task_id)
                continue

            for _ in range(count):
                pool.schedule_workset(work_set)
            task.state = TaskState.scheduled
            task.save()
            if schedule_workset(workset, pool, count):
                task.state = TaskState.scheduled
                task.save()