Mirror of https://github.com/microsoft/onefuzz.git (synced 2025-06-14 11:08:06 +00:00)
Colocate tasks (#402)
Enables co-locating multiple tasks in a given work-set.

Tasks are bucketed by the following:
* OS
* job id
* setup container
* VM SKU & image (used in pre-1.0 style tasks)
* pool name (used in 1.0+ style tasks)
* whether the task needs rebooting after the task setup script executes

Additionally, a task will end up in a unique bucket if any of the following are true:
* The task is set to run on more than one VM
* The task is missing the `task.config.colocate` flag (all tasks created prior to this functionality) or the value is False

This updates the libfuzzer template to make use of colocation. Users can specify co-locating all of the tasks *or* co-locating the secondary tasks.
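As a rough illustration of the bucketing rules above, the sketch below computes a bucket key for a task. It is a simplified, hypothetical restatement of the `bucket_tasks()` logic added in the scheduler diff further down: the `task` argument stands in for the service's `Task` model, and `setup_container` is passed in explicitly where the real code calls `get_setup_container(task.config)`.

from typing import Optional, Tuple
from uuid import UUID, uuid4


def bucket_key(task, setup_container: Optional[str]) -> Tuple:
    # Sketch only: mirrors bucket_tasks() from the scheduler diff below.
    vm: Optional[Tuple[str, str]] = None
    pool: Optional[str] = None
    unique: Optional[UUID] = None

    # pre-1.0 style tasks specify a VM SKU & image directly
    if task.config.vm:
        vm = (task.config.vm.sku, task.config.vm.image)
        if task.config.vm.count > 1:
            unique = uuid4()  # multi-VM tasks never share a work-set

    # 1.0+ style tasks specify a pool
    if task.config.pool:
        pool = task.config.pool.pool_name
        if task.config.pool.count > 1:
            unique = uuid4()

    # colocate missing (pre-existing tasks) or False -> private bucket
    if not task.config.colocate:
        unique = uuid4()

    return (
        task.os,
        task.job_id,
        vm,
        pool,
        setup_container,
        task.config.task.reboot_after_setup,
        unique,
    )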
@@ -413,6 +413,10 @@ Each event will be submitted via HTTP POST to the user provided URL.
                    "items": {
                        "$ref": "#/definitions/TaskDebugFlag"
                    }
                },
                "colocate": {
                    "title": "Colocate",
                    "type": "boolean"
                }
            },
            "required": [
@@ -1044,6 +1048,10 @@ Each event will be submitted via HTTP POST to the user provided URL.
                    "items": {
                        "$ref": "#/definitions/TaskDebugFlag"
                    }
                },
                "colocate": {
                    "title": "Colocate",
                    "type": "boolean"
                }
            },
            "required": [
@@ -130,6 +130,7 @@ def on_state_update(
        # if tasks are running on the node when it reports as Done
        # those are stopped early
        node.mark_tasks_stopped_early()
        node.to_reimage(done=True)

    # Model-validated.
    #
@@ -242,7 +243,6 @@ def on_worker_event_done(machine_id: UUID, event: WorkerDoneEvent) -> Result[Non
        node.debug_keep_node = True
        node.save()

    node.to_reimage(done=True)
    task_event = TaskEvent(
        task_id=task.task_id, machine_id=machine_id, event_data=WorkerEvent(done=event)
    )
@@ -237,6 +237,26 @@ libfuzzer_linux = JobTemplate(
                ),
            ],
        ),
        UserField(
            name="colocate",
            help="Run all of the tasks on the same node",
            type=UserFieldType.Bool,
            default=True,
            locations=[
                UserFieldLocation(
                    op=UserFieldOperation.add,
                    path="/tasks/0/colocate",
                ),
                UserFieldLocation(
                    op=UserFieldOperation.add,
                    path="/tasks/1/colocate",
                ),
                UserFieldLocation(
                    op=UserFieldOperation.add,
                    path="/tasks/2/colocate",
                ),
            ],
        ),
        UserField(
            name="expect_crash_on_failure",
            help="Require crashes upon non-zero exits from libfuzzer",
@@ -28,7 +28,7 @@ from ..webhooks import Webhook


class Task(BASE_TASK, ORMMixin):
    def ready_to_schedule(self) -> bool:
    def check_prereq_tasks(self) -> bool:
        if self.config.prereq_tasks:
            for task_id in self.config.prereq_tasks:
                task = Task.get_by_task_id(task_id)
@@ -4,24 +4,30 @@
# Licensed under the MIT License.

import logging
from typing import Dict, List
from uuid import UUID
from typing import Dict, Generator, List, Optional, Tuple, TypeVar
from uuid import UUID, uuid4

from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit
from pydantic import BaseModel

from ..azure.containers import (
    StorageType,
    blob_exists,
    get_container_sas_url,
    save_blob,
)
from ..azure.containers import StorageType, blob_exists, get_container_sas_url
from ..pools import Pool
from .config import build_task_config, get_setup_container
from .main import Task

HOURS = 60 * 60

# TODO: eventually, this should be tied to the pool.
MAX_TASKS_PER_SET = 10


A = TypeVar("A")


def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
    return (items[x : x + size] for x in range(0, len(items), size))


def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
    if pool.state not in PoolState.available():
@@ -39,29 +45,76 @@ def schedule_workset(workset: WorkSet, pool: Pool, count: int) -> bool:
    return True


def schedule_tasks() -> None:
    to_schedule: Dict[UUID, List[Task]] = {}
# TODO - Once Pydantic supports hashable models, the Tuple should be replaced
# with a model.
#
# For info: https://github.com/samuelcolvin/pydantic/pull/1881

    not_ready_count = 0

    for task in Task.search_states(states=[TaskState.waiting]):
        if not task.ready_to_schedule():
            not_ready_count += 1
            continue
def bucket_tasks(tasks: List[Task]) -> Dict[Tuple, List[Task]]:
    # buckets are hashed by:
    # OS, JOB ID, vm sku & image (if available), pool name (if available),
    # if the setup script requires rebooting, and a 'unique' value
    #
    # The unique value is set based on the following conditions:
    # * if the task is set to run on more than one VM, then we assume it can't be shared
    # * if the task is missing the 'colocate' flag or it's set to False

        if task.job_id not in to_schedule:
            to_schedule[task.job_id] = []
        to_schedule[task.job_id].append(task)

    if not to_schedule and not_ready_count > 0:
        logging.info("tasks not ready: %d", not_ready_count)

    for tasks in to_schedule.values():
        # TODO: for now, we're only scheduling one task per VM.
    buckets: Dict[Tuple, List[Task]] = {}

    for task in tasks:
        vm: Optional[Tuple[str, str]] = None
        pool: Optional[str] = None
        unique: Optional[UUID] = None

        # check for multiple VMs for pre-1.0.0 tasks
        if task.config.vm:
            vm = (task.config.vm.sku, task.config.vm.image)
            if task.config.vm.count > 1:
                unique = uuid4()

        # check for multiple VMs for 1.0.0 and later tasks
        if task.config.pool:
            pool = task.config.pool.pool_name
            if task.config.pool.count > 1:
                unique = uuid4()

        if not task.config.colocate:
            unique = uuid4()

        key = (
            task.os,
            task.job_id,
            vm,
            pool,
            get_setup_container(task.config),
            task.config.task.reboot_after_setup,
            unique,
        )
        if key not in buckets:
            buckets[key] = []
        buckets[key].append(task)

    return buckets


class BucketConfig(BaseModel):
    count: int
    reboot: bool
    setup_url: str
    setup_script: Optional[str]
    pool: Pool


def build_work_unit(task: Task) -> Optional[Tuple[BucketConfig, WorkUnit]]:
    pool = task.get_pool()
    if not pool:
        logging.info("unable to find pool for task: %s", task.task_id)
        return None

    logging.info("scheduling task: %s", task.task_id)
    agent_config = build_task_config(task.job_id, task.task_id, task.config)

    task_config = build_task_config(task.job_id, task.task_id, task.config)

    setup_container = get_setup_container(task.config)
    setup_url = get_container_sas_url(
@@ -79,48 +132,111 @@ def schedule_tasks() -> None:
    ):
        setup_script = "setup.sh"

    save_blob(
        "task-configs",
        "%s/config.json" % task.task_id,
        agent_config.json(exclude_none=True),
        StorageType.config,
    )
    reboot = False
    count = 1
    if task.config.pool:
        count = task.config.pool.count

        # NOTE: "is True" is required to handle Optional[bool]
        reboot = task.config.task.reboot_after_setup is True
    elif task.config.vm:
        # this branch should go away when we stop letting people specify
        # VM configs directly.
        count = task.config.vm.count

        # NOTE: "is True" is required to handle Optional[bool]
        reboot = (
            task.config.vm.reboot_after_setup is True
            or task.config.task.reboot_after_setup is True
        )
    else:
        raise TypeError

    task_config = agent_config
    task_config_json = task_config.json()
    work_unit = WorkUnit(
        job_id=task_config.job_id,
        task_id=task_config.task_id,
        task_type=task_config.task_type,
        config=task_config_json,
        config=task_config.json(),
    )

    # For now, only offer singleton work sets.
    workset = WorkSet(
    bucket_config = BucketConfig(
        pool=pool,
        count=count,
        reboot=reboot,
        script=(setup_script is not None),
        setup_script=setup_script,
        setup_url=setup_url,
        work_units=[work_unit],
    )

    pool = task.get_pool()
    if not pool:
        logging.info("unable to find pool for task: %s", task.task_id)
    return bucket_config, work_unit


def build_work_set(tasks: List[Task]) -> Optional[Tuple[BucketConfig, WorkSet]]:
    task_ids = [x.task_id for x in tasks]

    bucket_config: Optional[BucketConfig] = None
    work_units = []

    for task in tasks:
        if task.config.prereq_tasks:
            # if all of the prereqs are in this bucket, they will be
            # scheduled together
            if not all([task_id in task_ids for task_id in task.config.prereq_tasks]):
                if not task.check_prereq_tasks():
                    continue

        if schedule_workset(workset, pool, count):
        result = build_work_unit(task)
        if not result:
            continue

        new_bucket_config, work_unit = result
        if bucket_config is None:
            bucket_config = new_bucket_config
        else:
            if bucket_config != new_bucket_config:
                raise Exception(
                    f"bucket configs differ: {bucket_config} VS {new_bucket_config}"
                )

        work_units.append(work_unit)

    if bucket_config:
        work_set = WorkSet(
            reboot=bucket_config.reboot,
            script=(bucket_config.setup_script is not None),
            setup_url=bucket_config.setup_url,
            work_units=work_units,
        )
        return (bucket_config, work_set)

    return None


def schedule_tasks() -> None:
    tasks: List[Task] = []

    tasks = Task.search_states(states=[TaskState.waiting])

    tasks_by_id = {x.task_id: x for x in tasks}
    seen = set()

    not_ready_count = 0

    buckets = bucket_tasks(tasks)

    for bucketed_tasks in buckets.values():
        for chunk in chunks(bucketed_tasks, MAX_TASKS_PER_SET):
            result = build_work_set(chunk)
            if result is None:
                continue
            bucket_config, work_set = result

            if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
                for work_unit in work_set.work_units:
                    task = tasks_by_id[work_unit.task_id]
                    task.state = TaskState.scheduled
                    task.save()
                    seen.add(task.task_id)

    not_ready_count = len(tasks) - len(seen)
    if not_ready_count > 0:
        logging.info("tasks not ready: %d", not_ready_count)
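For clarity on the chunking above: each bucket is split into work-sets of at most MAX_TASKS_PER_SET tasks using chunks(). A quick, hypothetical example (the values are made up for illustration; only chunks() and the limit of 10 come from the diff above):

# chunks() as defined in the scheduler diff above
def chunks(items, size):
    return (items[x : x + size] for x in range(0, len(items), size))

bucketed = list(range(25))               # stand-in for a bucket of 25 tasks
work_sets = list(chunks(bucketed, 10))   # MAX_TASKS_PER_SET == 10
assert [len(ws) for ws in work_sets] == [10, 10, 5]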
src/api-service/tests/test_scheduler.py (new file, 147 lines)
@@ -0,0 +1,147 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest
from typing import Dict, Generator, List, TypeVar
from uuid import UUID, uuid4

from onefuzztypes.enums import OS, ContainerType, TaskType
from onefuzztypes.models import TaskConfig, TaskContainers, TaskDetails, TaskPool
from onefuzztypes.primitives import Container, PoolName

from __app__.onefuzzlib.tasks.main import Task
from __app__.onefuzzlib.tasks.scheduler import bucket_tasks

A = TypeVar("A")


def chunks(items: List[A], size: int) -> Generator[List[A], None, None]:
    return (items[x : x + size] for x in range(0, len(items), size))


class TestTaskBuckets(unittest.TestCase):
    def build_tasks(self, size: int) -> List[Task]:
        tasks = []
        for _ in range(size):
            task = Task(
                job_id=UUID(int=0),
                config=TaskConfig(
                    job_id=UUID(int=0),
                    task=TaskDetails(
                        type=TaskType.libfuzzer_fuzz,
                        duration=1,
                        target_exe="fuzz.exe",
                        target_env={},
                        target_options=[],
                    ),
                    pool=TaskPool(pool_name=PoolName("pool"), count=1),
                    containers=[
                        TaskContainers(
                            type=ContainerType.setup, name=Container("setup")
                        )
                    ],
                    tags={},
                    colocate=True,
                ),
                os=OS.linux,
            )
            tasks.append(task)
        return tasks

    def test_all_colocate(self) -> None:
        # all tasks should land in one bucket
        tasks = self.build_tasks(10)
        for task in tasks:
            task.config.colocate = True

        buckets = bucket_tasks(tasks)

        for bucket in buckets.values():
            self.assertEqual(len(bucket), 10)

        self.check_buckets(buckets, tasks, bucket_count=1)

    def test_partial_colocate(self) -> None:
        # 2 tasks should land on their own, the rest should be colocated into a
        # single bucket.

        tasks = self.build_tasks(10)

        # the task came before colocation was defined
        tasks[0].config.colocate = None

        # the task shouldn't be colocated
        tasks[1].config.colocate = False

        buckets = bucket_tasks(tasks)

        lengths = []
        for bucket in buckets.values():
            lengths.append(len(bucket))
        self.assertEqual([1, 1, 8], sorted(lengths))
        self.check_buckets(buckets, tasks, bucket_count=3)

    def test_all_unique_job(self) -> None:
        # everything has a unique job_id
        tasks = self.build_tasks(10)
        for task in tasks:
            job_id = uuid4()
            task.job_id = job_id
            task.config.job_id = job_id

        buckets = bucket_tasks(tasks)

        for bucket in buckets.values():
            self.assertEqual(len(bucket), 1)

        self.check_buckets(buckets, tasks, bucket_count=10)

    def test_multiple_job_buckets(self) -> None:
        # at most 3 tasks per bucket, by job_id
        tasks = self.build_tasks(10)
        for task_chunks in chunks(tasks, 3):
            job_id = uuid4()
            for task in task_chunks:
                task.job_id = job_id
                task.config.job_id = job_id

        buckets = bucket_tasks(tasks)

        for bucket in buckets.values():
            self.assertLessEqual(len(bucket), 3)

        self.check_buckets(buckets, tasks, bucket_count=4)

    def test_many_buckets(self) -> None:
        tasks = self.build_tasks(100)
        job_id = UUID(int=1)
        for i, task in enumerate(tasks):
            if i % 2 == 0:
                task.job_id = job_id
                task.config.job_id = job_id

            if i % 3 == 0:
                task.os = OS.windows

            if i % 4 == 0:
                task.config.containers[0].name = Container("setup2")

            if i % 5 == 0:
                if task.config.pool:
                    task.config.pool.pool_name = PoolName("alternate-pool")

        buckets = bucket_tasks(tasks)
        self.check_buckets(buckets, tasks, bucket_count=12)

    def check_buckets(self, buckets: Dict, tasks: List, *, bucket_count: int) -> None:
        self.assertEqual(len(buckets), bucket_count, "bucket count")

        for task in tasks:
            seen = False
            for bucket in buckets.values():
                if task in bucket:
                    self.assertEqual(seen, False, "task seen in multiple buckets")
                    seen = True
            self.assertEqual(seen, True, "task not seen in any buckets")
@@ -803,6 +803,7 @@ class Tasks(Endpoint):
        target_workers: Optional[int] = None,
        vm_count: int = 1,
        preserve_existing_outputs: bool = False,
        colocate: bool = False,
    ) -> models.Task:
        """
        Create a task
@@ -846,6 +847,7 @@ class Tasks(Endpoint):
            pool=models.TaskPool(count=vm_count, pool_name=pool_name),
            prereq_tasks=prereq_tasks,
            tags=tags,
            colocate=colocate,
            task=models.TaskDetails(
                analyzer_env=analyzer_env,
                analyzer_exe=analyzer_exe,
@@ -48,6 +48,8 @@ class Libfuzzer(Command):
        crash_report_timeout: Optional[int] = None,
        debug: Optional[List[TaskDebugFlag]] = None,
        ensemble_sync_delay: Optional[int] = None,
        colocate_all_tasks: bool = False,
        colocate_secondary_tasks: bool = True,
        check_fuzzer_help: bool = True,
        expect_crash_on_failure: bool = True,
    ) -> None:
@@ -78,10 +80,13 @@ class Libfuzzer(Command):
            tags=tags,
            debug=debug,
            ensemble_sync_delay=ensemble_sync_delay,
            colocate=colocate_all_tasks,
            check_fuzzer_help=check_fuzzer_help,
            expect_crash_on_failure=expect_crash_on_failure,
        )

        prereq_tasks = [fuzzer_task.task_id]

        coverage_containers = [
            (ContainerType.setup, containers[ContainerType.setup]),
            (ContainerType.coverage, containers[ContainerType.coverage]),
@@ -100,8 +105,9 @@ class Libfuzzer(Command):
            target_options=target_options,
            target_env=target_env,
            tags=tags,
            prereq_tasks=[fuzzer_task.task_id],
            prereq_tasks=prereq_tasks,
            debug=debug,
            colocate=colocate_all_tasks or colocate_secondary_tasks,
            check_fuzzer_help=check_fuzzer_help,
        )

@@ -126,11 +132,12 @@ class Libfuzzer(Command):
            target_options=target_options,
            target_env=target_env,
            tags=tags,
            prereq_tasks=[fuzzer_task.task_id],
            prereq_tasks=prereq_tasks,
            target_timeout=crash_report_timeout,
            check_retry_count=check_retry_count,
            check_fuzzer_help=check_fuzzer_help,
            debug=debug,
            colocate=colocate_all_tasks or colocate_secondary_tasks,
        )

    def basic(
@@ -160,6 +167,8 @@ class Libfuzzer(Command):
        notification_config: Optional[NotificationConfig] = None,
        debug: Optional[List[TaskDebugFlag]] = None,
        ensemble_sync_delay: Optional[int] = None,
        colocate_all_tasks: bool = False,
        colocate_secondary_tasks: bool = True,
        check_fuzzer_help: bool = True,
        expect_crash_on_failure: bool = True,
    ) -> Optional[Job]:
@@ -237,6 +246,8 @@ class Libfuzzer(Command):
            check_retry_count=check_retry_count,
            debug=debug,
            ensemble_sync_delay=ensemble_sync_delay,
            colocate_all_tasks=colocate_all_tasks,
            colocate_secondary_tasks=colocate_secondary_tasks,
            check_fuzzer_help=check_fuzzer_help,
            expect_crash_on_failure=expect_crash_on_failure,
        )
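For reference, a hedged usage sketch of the new template options: colocate_all_tasks groups the fuzzing task together with the secondary tasks, while the default colocate_secondary_tasks=True groups only the coverage and report tasks. The SDK entry point and the positional project/name/build/pool arguments are assumptions for illustration; only the colocate_* keyword options come from this diff.

from onefuzz.api import Onefuzz

o = Onefuzz()

# Assumed positional arguments; only the colocate_* keywords come from this diff.
o.template.libfuzzer.basic(
    "my-project", "my-target", "build-1", "linux-pool",
    colocate_all_tasks=True,  # fuzz + coverage + report tasks share a work-set
)

# Leaving colocate_all_tasks=False (the default) while keeping the default
# colocate_secondary_tasks=True groups only the coverage/report tasks.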
@@ -192,6 +192,7 @@ class TaskConfig(BaseModel):
    containers: List[TaskContainers]
    tags: Dict[str, str]
    debug: Optional[List[TaskDebugFlag]]
    colocate: Optional[bool]


class BlobRef(BaseModel):