Remove update_event as a single event loop for the system (#160)

Authored by bmc-msft on 2020-10-16 21:42:35 -04:00; committed by GitHub
parent 9fa25803ab
commit 75f29b9f2e
24 changed files with 418 additions and 324 deletions
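
At a glance, this change deletes the single update_event timer that requeued every object type on one loop, and replaces it with dedicated timers (timer_proxy, timer_repro, timer_tasks, and an expanded timer_daemon) that drive each object through its state machine via the new process_state_updates helper. A minimal sketch of the shared pattern, assuming a hypothetical Thing model that follows the usual search_states/needs_work conventions:

#!/usr/bin/env python
import logging

import azure.functions as func

from ..onefuzzlib.orm import process_state_updates
from ..onefuzzlib.things import Thing, ThingState  # hypothetical model/enum


def main(mytimer: func.TimerRequest) -> None:  # noqa: F841
    # each timer now owns exactly one object type, instead of a single
    # update_event loop requeueing everything
    for thing in Thing.search_states(states=ThingState.needs_work()):
        logging.info("update thing: %s", thing.thing_id)
        process_state_updates(thing)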

View File

@ -98,6 +98,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
node.version = registration_request.version
node.reimage_requested = False
node.state = NodeState.init
node.reimage_queued = False
else:
node = Node(
pool_name=registration_request.pool_name,

View File

@ -7,12 +7,11 @@ import logging
import math
from typing import List
import azure.functions as func
from onefuzztypes.enums import NodeState, PoolState, ScalesetState
from onefuzztypes.enums import NodeState, ScalesetState
from onefuzztypes.models import AutoScaleConfig, TaskPool
from ..onefuzzlib.pools import Node, Pool, Scaleset
from ..onefuzzlib.tasks.main import Task
from .pools import Node, Pool, Scaleset
from .tasks.main import Task
def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
@ -22,11 +21,11 @@ def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
return
for scaleset in scalesets:
if scaleset.state == ScalesetState.running:
if scaleset.state in [ScalesetState.running, ScalesetState.resize]:
max_size = min(scaleset.max_size(), autoscale_config.scaleset_size)
logging.info(
"Sacleset id: %s, Scaleset size: %d, max_size: %d"
"scaleset:%s size:%d max_size:%d"
% (scaleset.scaleset_id, scaleset.size, max_size)
)
if scaleset.size < max_size:
@ -114,12 +113,10 @@ def get_vm_count(tasks: List[Task]) -> int:
return count
def main(mytimer: func.TimerRequest) -> None: # noqa: F841
pools = Pool.search_states(states=PoolState.available())
for pool in pools:
def autoscale_pool(pool: Pool) -> None:
logging.info("autoscale: %s" % (pool.autoscale))
if not pool.autoscale:
continue
return
# get all the tasks (count not stopped) for the pool
tasks = Task.get_tasks_by_pool_name(pool.name)
@ -141,7 +138,7 @@ def main(mytimer: func.TimerRequest) -> None: # noqa: F841
nodes_needed = nodes_needed - scaleset.size
if pool_resize:
continue
return
logging.info("Pool: %s, #Nodes Needed: %d" % (pool.name, nodes_needed))
if nodes_needed > 0:

View File

@ -73,9 +73,6 @@ class Job(BASE_JOB, ORMMixin):
self.state = JobState.stopped
self.save()
def queue_stop(self) -> None:
self.queue(method=self.stopping)
def on_start(self) -> None:
# try to keep this effectively idempotent
if self.end_time is None:

View File

@ -36,6 +36,7 @@ from onefuzztypes.enums import (
from onefuzztypes.models import Error
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from typing_extensions import Protocol
from .azure.table import get_client
from .dashboard import add_event
@ -66,6 +67,36 @@ KEY = Union[int, str, UUID, Enum]
HOURS = 60 * 60
class HasState(Protocol):
# TODO: this should be bound tighter than Any
# In the end, we want this to be an Enum. Specifically, one of
# the JobState,TaskState,etc enums.
state: Any
def process_state_update(obj: HasState) -> None:
"""
process a single state update, if the obj
implements a function for that state
"""
func = getattr(obj, obj.state.name, None)
if func is None:
return
func()
def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
""" process through the state machine for an object """
for _ in range(max_updates):
state = obj.state
process_state_update(obj)
new_state = obj.state
if new_state == state:
break
def resolve(key: KEY) -> str:
if isinstance(key, str):
return key
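
The getattr dispatch above is the whole mechanism: an object's state enum names a method on that object, and process_state_updates keeps calling the matching method until the state stops changing or max_updates is reached. A toy, self-contained illustration (the Light class and LightState enum are invented for this sketch):

from enum import Enum


class LightState(Enum):
    init = "init"
    running = "running"
    done = "done"


class Light:
    def __init__(self) -> None:
        self.state = LightState.init

    # method names match state names, so getattr(obj, obj.state.name) finds them
    def init(self) -> None:
        self.state = LightState.running

    def running(self) -> None:
        self.state = LightState.done

    # no method named "done", so processing stops there


light = Light()
process_state_updates(light)  # init -> running -> done, then no-op
assert light.state is LightState.done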

View File

@ -69,6 +69,10 @@ from .orm import MappingIntStrAny, ORMMixin, QueryFilter
class Node(BASE_NODE, ORMMixin):
# should only be set by Scaleset.reimage_nodes
# should only be unset during agent_registration POST
reimage_queued: bool = Field(default=False)
@classmethod
def search_states(
cls,
@ -108,6 +112,21 @@ class Node(BASE_NODE, ORMMixin):
version_query = "not (version eq '%s')" % __version__
return cls.search(query=query, raw_unchecked_filter=version_query)
@classmethod
def mark_outdated_nodes(cls) -> None:
outdated = cls.search_outdated()
for node in outdated:
logging.info(
"node is outdated: %s - node_version:%s api_version:%s",
node.machine_id,
node.version,
__version__,
)
if node.version == "1.0.0":
node.to_reimage(done=True)
else:
node.to_reimage()
@classmethod
def get_by_machine_id(cls, machine_id: UUID) -> Optional["Node"]:
nodes = cls.search(query={"machine_id": [machine_id]})
@ -195,9 +214,24 @@ class Node(BASE_NODE, ORMMixin):
self.stop()
return False
if self.delete_requested or self.reimage_requested:
if self.state in NodeState.ready_for_reset():
logging.info(
"can_schedule should be recycled. machine_id:%s", self.machine_id
"can_schedule node is set for reset. machine_id:%s", self.machine_id
)
return False
if self.delete_requested:
logging.info(
"can_schedule is set to be deleted. machine_id:%s",
self.machine_id,
)
self.stop()
return False
if self.reimage_requested:
logging.info(
"can_schedule is set to be reimaged. machine_id:%s",
self.machine_id,
)
self.stop()
return False
@ -682,25 +716,11 @@ class Scaleset(BASE_SCALESET, ORMMixin):
to_reimage = []
to_delete = []
outdated = Node.search_outdated(scaleset_id=self.scaleset_id)
for node in outdated:
logging.info(
"node is outdated: %s - node_version:%s api_version:%s",
node.machine_id,
node.version,
__version__,
)
if node.version == "1.0.0":
node.state = NodeState.done
to_reimage.append(node)
else:
node.to_reimage()
nodes = Node.search_states(
scaleset_id=self.scaleset_id, states=NodeState.ready_for_reset()
)
if not outdated and not nodes:
if not nodes:
logging.info("no nodes need updating: %s", self.scaleset_id)
return False
@ -719,7 +739,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
if ScalesetShrinkQueue(self.scaleset_id).should_shrink():
node.set_halt()
to_delete.append(node)
else:
elif not node.reimage_queued:
# only add nodes that are not already set to reschedule
to_reimage.append(node)
# Perform operations until they fail due to scaleset getting locked
@ -833,6 +854,9 @@ class Scaleset(BASE_SCALESET, ORMMixin):
"unable to reimage nodes: %s:%s - %s"
% (self.scaleset_id, machine_ids, result)
)
for node in nodes:
node.reimage_queued = True
node.save()
def shutdown(self) -> None:
size = get_vmss_size(self.scaleset_id)
@ -855,7 +879,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.save()
else:
logging.info("scaleset deleted: %s", self.scaleset_id)
self.state = ScalesetState.halt
self.delete()
@classmethod
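
Read together, the pools.py hunks give the new reimage_queued flag a single-owner lifecycle; a condensed summary as comments (paraphrasing the diff, not verbatim code):

# reimage_queued lifecycle, as wired up above:
# 1. Scaleset.cleanup_nodes only queues nodes with reimage_queued=False,
#    so a node cannot be queued for reimage twice
# 2. Scaleset.reimage_nodes sets node.reimage_queued = True (and saves)
#    once the reimage request succeeds
# 3. the agent registration POST resets node.reimage_queued = False when
#    the freshly imaged node re-registers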

View File

@ -27,7 +27,7 @@ from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.vm import VM
from .extension import proxy_manager_extensions
from .orm import HOURS, MappingIntStrAny, ORMMixin, QueryFilter
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
PROXY_SKU = "Standard_B2s"
@ -210,9 +210,6 @@ class Proxy(ORMMixin):
account_id=os.environ["ONEFUZZ_FUNC_STORAGE"],
)
def queue_stop(self, count: int) -> None:
self.queue(method=self.stopping, visibility_timeout=count * HOURS)
@classmethod
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Proxy"]:
query: QueryFilter = {}

View File

@ -4,6 +4,7 @@
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from azure.mgmt.compute.models import VirtualMachine
@ -18,7 +19,7 @@ from .azure.creds import get_base_region, get_func_storage
from .azure.ip import get_public_ip
from .azure.vm import VM
from .extension import repro_extensions
from .orm import HOURS, ORMMixin, QueryFilter
from .orm import ORMMixin, QueryFilter
from .reports import get_report
from .tasks.main import Task
@ -205,9 +206,6 @@ class Repro(BASE_REPRO, ORMMixin):
logging.info("saved repro script")
return None
def queue_stop(self, count: int) -> None:
self.queue(method=self.stopping, visibility_timeout=count * HOURS)
@classmethod
def search_states(cls, *, states: Optional[List[VmState]] = None) -> List["Repro"]:
query: QueryFilter = {}
@ -228,10 +226,18 @@ class Repro(BASE_REPRO, ORMMixin):
return task
vm = cls(config=config, task_id=task.task_id, os=task.os, auth=build_auth())
if vm.end_time is None:
vm.end_time = datetime.utcnow() + timedelta(hours=config.duration)
vm.save()
vm.queue_stop(config.duration)
return vm
@classmethod
def search_expired(cls) -> List["Repro"]:
# unlike jobs/tasks, the entry is deleted from the backing table upon stop
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
return cls.search(raw_unchecked_filter=time_filter)
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("vm_id", None)
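
For reference, search_expired builds a raw Azure Table filter string rather than a typed query; a quick illustration of what it expands to (timestamp invented):

from datetime import datetime

time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
# e.g. "end_time lt datetime'2020-10-17T01:42:35.123456'"
# i.e. every Repro row whose end_time is earlier than now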

View File

@ -118,9 +118,6 @@ class Task(BASE_TASK, ORMMixin):
self.state = TaskState.stopped
self.save()
def queue_stop(self) -> None:
self.queue(method=self.stopping)
@classmethod
def search_states(
cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
@ -165,7 +162,7 @@ class Task(BASE_TASK, ORMMixin):
task_pool = task.get_pool()
if not task_pool:
continue
if pool_name == task_pool.name and task.state in TaskState.available():
if pool_name == task_pool.name:
pool_tasks.append(task)
return pool_tasks

View File

@ -7,17 +7,34 @@ import logging
from typing import Dict, List
from uuid import UUID
from onefuzztypes.enums import OS, TaskState
from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from ..azure.containers import blob_exists, get_container_sas_url, save_blob
from ..azure.creds import get_func_storage
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task
HOURS = 60 * 60
def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
if pool.state not in PoolState.available():
logging.info(
"pool not available for work: %s state: %s", pool.name, pool.state.name
)
return False
for _ in range(count):
if not pool.schedule_workset(workset):
logging.error(
"unable to schedule workset. pool:%s workset:%s", pool.name, workset
)
return False
return True
def schedule_tasks() -> None:
to_schedule: Dict[UUID, List[Task]] = {}
@ -82,7 +99,7 @@ def schedule_tasks() -> None:
)
# For now, only offer singleton work sets.
work_set = WorkSet(
workset = WorkSet(
reboot=reboot,
script=(setup_script is not None),
setup_url=setup_url,
@ -94,7 +111,6 @@ def schedule_tasks() -> None:
logging.info("unable to find pool for task: %s", task.task_id)
continue
for _ in range(count):
pool.schedule_workset(work_set)
if schedule_workset(workset, pool, count):
task.state = TaskState.scheduled
task.save()

View File

@ -32,5 +32,6 @@ PyJWT~=1.7.1
requests~=2.24.0
memoization~=0.3.1
github3.py~=1.3.0
typing-extensions~=3.7.4.3
# onefuzz types version is set during build
onefuzztypes==0.0.0

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.enums import JobState, NodeState, PoolState, TaskState, VmState
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.jobs import Job
from ..onefuzzlib.pools import Node, Pool
from ..onefuzzlib.proxy import Proxy
from ..onefuzzlib.repro import Repro
from ..onefuzzlib.tasks.main import Task
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
proxies = Proxy.search_states(states=VmState.needs_work())
for proxy in proxies:
logging.info("requeueing update proxy vm: %s", proxy.region)
proxy.queue()
vms = Repro.search_states(states=VmState.needs_work())
for vm in vms:
logging.info("requeueing update vm: %s", vm.vm_id)
vm.queue()
tasks = Task.search_states(states=TaskState.needs_work())
for task in tasks:
logging.info("requeueing update task: %s", task.task_id)
task.queue()
jobs = Job.search_states(states=JobState.needs_work())
for job in jobs:
logging.info("requeueing update job: %s", job.job_id)
job.queue()
pools = Pool.search_states(states=PoolState.needs_work())
for pool in pools:
logging.info("queuing update pool: %s (%s)", pool.pool_id, pool.name)
pool.queue()
nodes = Node.search_states(states=NodeState.needs_work())
for node in nodes:
logging.info("queuing update node: %s", node.machine_id)
node.queue()
expired_tasks = Task.search_expired()
for task in expired_tasks:
logging.info("queuing stop for task: %s", task.job_id)
task.queue_stop()
expired_jobs = Job.search_expired()
for job in expired_jobs:
logging.info("queuing stop for job: %s", job.job_id)
job.queue_stop()
# Reminder, proxies are created on-demand. If something is "wrong" with
# a proxy, the plan is: delete and recreate it.
for proxy in Proxy.search():
if not proxy.is_alive():
logging.error("proxy alive check failed, stopping: %s", proxy.region)
proxy.state = VmState.stopping
proxy.save()
else:
proxy.save_proxy_config()
event = get_event()
if event:
dashboard.set(event)

View File

@ -1,11 +0,0 @@
{
"scriptFile": "__init__.py",
"bindings": [
{
"name": "mytimer",
"type": "timerTrigger",
"direction": "in",
"schedule": "00:01:00"
}
]
}

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.enums import VmState
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.orm import process_state_updates
from ..onefuzzlib.proxy import Proxy
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
# Reminder, proxies are created on-demand. If something is "wrong" with
# a proxy, the plan is: delete and recreate it.
for proxy in Proxy.search():
if not proxy.is_alive():
logging.error("proxy alive check failed, stopping: %s", proxy.region)
proxy.state = VmState.stopping
proxy.save()
else:
proxy.save_proxy_config()
if proxy.state in VmState.needs_work():
logging.info("update proxy vm: %s", proxy.region)
process_state_updates(proxy)
event = get_event()
if event:
dashboard.set(event)

View File

@ -0,0 +1,17 @@
{
"bindings": [
{
"direction": "in",
"name": "mytimer",
"schedule": "00:00:30",
"type": "timerTrigger"
},
{
"type": "signalR",
"direction": "out",
"name": "dashboard",
"hubName": "dashboard"
}
],
"scriptFile": "__init__.py"
}

View File

@ -0,0 +1,33 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.enums import VmState
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.orm import process_state_updates
from ..onefuzzlib.repro import Repro
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
expired = Repro.search_expired()
for repro in expired:
logging.info("stopping repro: %s", repro.vm_id)
repro.stopping()
expired_vm_ids = [x.vm_id for x in expired]
for repro in Repro.search_states(states=VmState.needs_work()):
if repro.vm_id in expired_vm_ids:
# this VM already got processed during the expired phase
continue
logging.info("update repro: %s", repro.vm_id)
process_state_updates(repro)
event = get_event()
if event:
dashboard.set(event)

View File

@ -3,7 +3,7 @@
{
"direction": "in",
"name": "mytimer",
"schedule": "0 */5 * * * *",
"schedule": "00:00:30",
"type": "timerTrigger"
},
{
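
One note on this binding change: Azure Functions timer triggers accept either a six-field NCRONTAB expression or a constant-interval TimeSpan string, so moving from "0 */5 * * * *" to "00:00:30" changes both format and cadence (every five minutes to every thirty seconds). Per Azure's documentation, the TimeSpan form is only supported when the function app runs on an App Service plan:

"schedule": "0 */5 * * * *"   (NCRONTAB: sec min hour day month weekday)
"schedule": "00:00:30"        (TimeSpan: hh:mm:ss, fires every 30 seconds)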

View File

@ -1,17 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import azure.functions as func
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.tasks.scheduler import schedule_tasks
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
schedule_tasks()
event = get_event()
if event:
dashboard.set(event)

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import azure.functions as func
from onefuzztypes.enums import JobState, TaskState
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.jobs import Job
from ..onefuzzlib.orm import process_state_updates
from ..onefuzzlib.tasks.main import Task
from ..onefuzzlib.tasks.scheduler import schedule_tasks
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
expired_tasks = Task.search_expired()
for task in expired_tasks:
logging.info("stopping expired task: %s", task.job_id)
task.stopping()
expired_jobs = Job.search_expired()
for job in expired_jobs:
logging.info("stopping expired job: %s", job.job_id)
job.stopping()
jobs = Job.search_states(states=JobState.needs_work())
for job in jobs:
logging.info("update job: %s", job.job_id)
process_state_updates(job)
tasks = Task.search_states(states=TaskState.needs_work())
for task in tasks:
logging.info("update task: %s", task.task_id)
process_state_updates(task)
schedule_tasks()
event = get_event()
if event:
dashboard.set(event)

View File

@ -6,43 +6,44 @@
import logging
import azure.functions as func
from onefuzztypes.enums import ScalesetState
from onefuzztypes.enums import NodeState, PoolState
from ..onefuzzlib.autoscale import autoscale_pool
from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.pools import Scaleset
from ..onefuzzlib.orm import process_state_updates
from ..onefuzzlib.pools import Node, Pool, Scaleset
def process_scaleset(scaleset: Scaleset) -> None:
logging.debug("checking scaleset for updates: %s", scaleset.scaleset_id)
if scaleset.state == ScalesetState.resize:
scaleset.resize()
# if the scaleset is touched during cleanup, don't continue to process it
if scaleset.cleanup_nodes():
logging.debug("scaleset needed cleanup: %s", scaleset.scaleset_id)
return
if (
scaleset.state in ScalesetState.needs_work()
and scaleset.state != ScalesetState.resize
):
logging.info(
"exec scaleset state: %s - %s",
scaleset.scaleset_id,
scaleset.state,
)
if hasattr(scaleset, scaleset.state.name):
getattr(scaleset, scaleset.state.name)()
return
process_state_updates(scaleset)
def main(mytimer: func.TimerRequest, dashboard: func.Out[str]) -> None: # noqa: F841
Node.mark_outdated_nodes()
nodes = Node.search_states(states=NodeState.needs_work())
for node in nodes:
logging.info("update node: %s", node.machine_id)
process_state_updates(node)
scalesets = Scaleset.search()
for scaleset in scalesets:
process_scaleset(scaleset)
pools = Pool.search()
for pool in pools:
if pool.state in PoolState.needs_work():
logging.info("update pool: %s (%s)", pool.pool_id, pool.name)
process_state_updates(pool)
elif pool.state in PoolState.available() and pool.autoscale:
autoscale_pool(pool)
event = get_event()
if event:
dashboard.set(event)

View File

@ -10,7 +10,7 @@ from uuid import UUID
from __app__.onefuzzlib.orm import ORMMixin, build_filters
class TestOrm(ORMMixin):
class BasicOrm(ORMMixin):
a: int
b: UUID
c: str
@ -27,38 +27,38 @@ class TestQueryBuilder(unittest.TestCase):
self.maxDiff = 999999999999999
self.assertEqual(
build_filters(TestOrm, {"a": [1]}), ("a eq 1", {}), "handle integer"
build_filters(BasicOrm, {"a": [1]}), ("a eq 1", {}), "handle integer"
)
self.assertEqual(
build_filters(
TestOrm, {"b": [UUID("06aa1e71-b025-4325-9983-4b3ce2de12ea")]}
BasicOrm, {"b": [UUID("06aa1e71-b025-4325-9983-4b3ce2de12ea")]}
),
("b eq '06aa1e71-b025-4325-9983-4b3ce2de12ea'", {}),
"handle UUID",
)
self.assertEqual(
build_filters(TestOrm, {"a": ["a"]}), (None, {"a": ["a"]}), "handle str"
build_filters(BasicOrm, {"a": ["a"]}), (None, {"a": ["a"]}), "handle str"
)
self.assertEqual(
build_filters(TestOrm, {"a": [1, 2]}),
build_filters(BasicOrm, {"a": [1, 2]}),
("(a eq 1 or a eq 2)", {}),
"multiple values",
)
self.assertEqual(
build_filters(TestOrm, {"a": ["b"], "c": ["d"]}),
build_filters(BasicOrm, {"a": ["b"], "c": ["d"]}),
(None, {"a": ["b"], "c": ["d"]}),
"multiple fields",
)
self.assertEqual(
build_filters(TestOrm, {"a": [1, 2], "c": [3]}),
build_filters(BasicOrm, {"a": [1, 2], "c": [3]}),
("(a eq 1 or a eq 2) and c eq 3", {}),
"multiple fields and values",
)
self.assertEqual(
build_filters(
TestOrm,
BasicOrm,
{
"a": ["b"],
"b": [1],
@ -70,13 +70,13 @@ class TestQueryBuilder(unittest.TestCase):
)
self.assertEqual(
build_filters(TestOrm, {"d": [1, 2], "e": [3]}),
build_filters(BasicOrm, {"d": [1, 2], "e": [3]}),
("(PartitionKey eq 1 or PartitionKey eq 2) and RowKey eq 3", {}),
"query on keyfields",
)
with self.assertRaises(ValueError):
build_filters(TestOrm, {"test1": ["b", "c"], "test2": ["d"]})
build_filters(BasicOrm, {"test1": ["b", "c"], "test2": ["d"]})
if __name__ == "__main__":

View File

@ -256,7 +256,7 @@ class PoolState(Enum):
@classmethod
def available(cls) -> List["PoolState"]:
""" set of states that indicate if it's available for work """
return [cls.init, cls.running]
return [cls.running]
class ScalesetState(Enum):

View File

@ -582,6 +582,7 @@ class Repro(BaseModel):
os: OS
error: Optional[Error]
ip: Optional[str]
end_time: Optional[datetime]
class ExitStatus(BaseModel):