mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-13 10:38:08 +00:00
232 lines
7.6 KiB
Python
232 lines
7.6 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Optional, Tuple, Union
|
|
from uuid import UUID
|
|
|
|
from onefuzztypes.enums import ErrorCode, TaskState
|
|
from onefuzztypes.models import Error
|
|
from onefuzztypes.models import Task as BASE_TASK
|
|
from onefuzztypes.models import TaskConfig, TaskVm
|
|
|
|
from ..azure.creds import get_fuzz_storage
|
|
from ..azure.image import get_os
|
|
from ..azure.queue import create_queue, delete_queue
|
|
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
|
|
from ..pools import Node, Pool, Scaleset
|
|
from ..proxy_forward import ProxyForward
|
|
|
|
|
|
class Task(BASE_TASK, ORMMixin):
|
|
def ready_to_schedule(self) -> bool:
|
|
if self.config.prereq_tasks:
|
|
for task_id in self.config.prereq_tasks:
|
|
task = Task.get_by_task_id(task_id)
|
|
# if a prereq task fails, then mark this task as failed
|
|
if isinstance(task, Error):
|
|
self.error = task
|
|
self.state = TaskState.stopping
|
|
self.save()
|
|
return False
|
|
|
|
if task.state not in task.state.has_started():
|
|
return False
|
|
return True
|
|
|
|
@classmethod
|
|
def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
|
|
if config.vm:
|
|
os = get_os(config.vm.region, config.vm.image)
|
|
elif config.pool:
|
|
pool = Pool.get_by_name(config.pool.pool_name)
|
|
if isinstance(pool, Error):
|
|
return pool
|
|
os = pool.os
|
|
else:
|
|
raise Exception("task must have vm or pool")
|
|
task = cls(config=config, job_id=job_id, os=os)
|
|
task.save()
|
|
return task
|
|
|
|
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
|
return {"heartbeats": ...}
|
|
|
|
def is_ready(self) -> bool:
|
|
if self.config.prereq_tasks:
|
|
for prereq_id in self.config.prereq_tasks:
|
|
prereq = Task.get_by_task_id(prereq_id)
|
|
if isinstance(prereq, Error):
|
|
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
|
|
self.error = prereq
|
|
self.state = TaskState.stopping
|
|
self.save()
|
|
return False
|
|
if prereq.state != TaskState.running:
|
|
logging.info(
|
|
"task is waiting on prereq: %s - %s:",
|
|
self.task_id,
|
|
prereq.task_id,
|
|
)
|
|
return False
|
|
|
|
return True
|
|
|
|
# At current, the telemetry filter will generate something like this:
|
|
#
|
|
# {
|
|
# 'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
|
|
# 'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
|
|
# 'state': 'stopped',
|
|
# 'config': {
|
|
# 'task': {'type': 'libfuzzer_coverage'},
|
|
# 'vm': {'count': 1}
|
|
# }
|
|
# }
|
|
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
|
return {
|
|
"job_id": ...,
|
|
"task_id": ...,
|
|
"state": ...,
|
|
"config": {"vm": {"count": ...}, "task": {"type": ...}},
|
|
}
|
|
|
|
def event_include(self) -> Optional[MappingIntStrAny]:
|
|
return {
|
|
"job_id": ...,
|
|
"task_id": ...,
|
|
"state": ...,
|
|
"error": ...,
|
|
}
|
|
|
|
def init(self) -> None:
|
|
create_queue(self.task_id, account_id=get_fuzz_storage())
|
|
self.state = TaskState.waiting
|
|
self.save()
|
|
|
|
def stopping(self) -> None:
|
|
# TODO: we need to tell every node currently working on this task to stop
|
|
# TODO: we need to 'unschedule' this task from the existing pools
|
|
|
|
self.state = TaskState.stopping
|
|
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
|
|
ProxyForward.remove_forward(self.task_id)
|
|
delete_queue(str(self.task_id), account_id=get_fuzz_storage())
|
|
Node.stop_task(self.task_id)
|
|
self.state = TaskState.stopped
|
|
self.save()
|
|
|
|
def queue_stop(self) -> None:
|
|
self.queue(method=self.stopping)
|
|
|
|
@classmethod
|
|
def search_states(
|
|
cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
|
|
) -> List["Task"]:
|
|
query: QueryFilter = {}
|
|
if job_id:
|
|
query["job_id"] = [job_id]
|
|
if states:
|
|
query["state"] = states
|
|
|
|
return cls.search(query=query)
|
|
|
|
@classmethod
|
|
def search_expired(cls) -> List["Task"]:
|
|
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
|
|
return cls.search(
|
|
query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
|
|
)
|
|
|
|
@classmethod
|
|
def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
|
|
tasks = cls.search(query={"task_id": [task_id]})
|
|
if not tasks:
|
|
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])
|
|
|
|
if len(tasks) != 1:
|
|
return Error(
|
|
code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
|
|
)
|
|
task = tasks[0]
|
|
return task
|
|
|
|
def get_pool(self) -> Optional[Pool]:
|
|
if self.config.pool:
|
|
pool = Pool.get_by_name(self.config.pool.pool_name)
|
|
if isinstance(pool, Error):
|
|
logging.info(
|
|
"unable to schedule task to pool: %s - %s", self.task_id, pool
|
|
)
|
|
return None
|
|
return pool
|
|
elif self.config.vm:
|
|
scalesets = Scaleset.search()
|
|
scalesets = [
|
|
x
|
|
for x in scalesets
|
|
if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
|
|
]
|
|
for scaleset in scalesets:
|
|
pool = Pool.get_by_name(scaleset.pool_name)
|
|
if isinstance(pool, Error):
|
|
logging.info(
|
|
"unable to schedule task to pool: %s - %s",
|
|
self.task_id,
|
|
pool,
|
|
)
|
|
else:
|
|
return pool
|
|
|
|
logging.warning(
|
|
"unable to find a scaleset that matches the task prereqs: %s",
|
|
self.task_id,
|
|
)
|
|
return None
|
|
|
|
def get_repro_vm_config(self) -> Union[TaskVm, None]:
|
|
if self.config.vm:
|
|
return self.config.vm
|
|
|
|
if self.config.pool is None:
|
|
raise Exception("either pool or vm must be specified: %s" % self.task_id)
|
|
|
|
pool = Pool.get_by_name(self.config.pool.pool_name)
|
|
if isinstance(pool, Error):
|
|
logging.info("unable to find pool from task: %s", self.task_id)
|
|
return None
|
|
|
|
for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
|
|
return TaskVm(
|
|
region=scaleset.region,
|
|
sku=scaleset.vm_sku,
|
|
image=scaleset.image,
|
|
)
|
|
|
|
logging.warning(
|
|
"no scalesets are defined for task: %s:%s", self.job_id, self.task_id
|
|
)
|
|
return None
|
|
|
|
def on_start(self) -> None:
|
|
# try to keep this effectively idempotent
|
|
|
|
if self.end_time is None:
|
|
self.end_time = datetime.utcnow() + timedelta(
|
|
hours=self.config.task.duration
|
|
)
|
|
self.save()
|
|
|
|
from ..jobs import Job
|
|
|
|
job = Job.get(self.job_id)
|
|
if job:
|
|
job.on_start()
|
|
|
|
@classmethod
|
|
def key_fields(cls) -> Tuple[str, str]:
|
|
return ("job_id", "task_id")
|