mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-13 10:38:08 +00:00
initial public release
This commit is contained in:
231
src/api-service/__app__/onefuzzlib/tasks/main.py
Normal file
231
src/api-service/__app__/onefuzzlib/tasks/main.py
Normal file
@ -0,0 +1,231 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.enums import ErrorCode, TaskState
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Task as BASE_TASK
|
||||
from onefuzztypes.models import TaskConfig, TaskVm
|
||||
|
||||
from ..azure.creds import get_fuzz_storage
|
||||
from ..azure.image import get_os
|
||||
from ..azure.queue import create_queue, delete_queue
|
||||
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
|
||||
from ..pools import Node, Pool, Scaleset
|
||||
from ..proxy_forward import ProxyForward
|
||||
|
||||
|
||||
class Task(BASE_TASK, ORMMixin):
    """ORM-backed fuzzing task.

    Extends the wire-level task model (BASE_TASK) with persistence
    (ORMMixin) and the state transitions used by the scheduler.
    """

    def ready_to_schedule(self) -> bool:
        """Return True once every prerequisite task has started.

        Side effect: if a prereq resolved to an Error, this task is moved
        to `stopping` and saved before returning False.

        NOTE(review): overlaps heavily with is_ready() below, which checks
        for `running` instead of has_started() — confirm both are needed.
        """
        if self.config.prereq_tasks:
            for task_id in self.config.prereq_tasks:
                task = Task.get_by_task_id(task_id)
                # if a prereq task fails, then mark this task as failed
                if isinstance(task, Error):
                    self.error = task
                    self.state = TaskState.stopping
                    self.save()
                    return False

                # was `task.state.has_started()` — has_started is a
                # classmethod on TaskState, so call it on the enum class
                # directly rather than through a member instance
                if task.state not in TaskState.has_started():
                    return False
        return True

    @classmethod
    def create(cls, config: TaskConfig, job_id: UUID) -> Union["Task", Error]:
        """Create and persist a task, deriving its OS from the vm image
        or the target pool.

        Returns the saved Task, or the Error from the pool lookup.
        Raises if the config names neither a vm nor a pool.
        """
        if config.vm:
            os = get_os(config.vm.region, config.vm.image)
        elif config.pool:
            pool = Pool.get_by_name(config.pool.pool_name)
            if isinstance(pool, Error):
                return pool
            os = pool.os
        else:
            raise Exception("task must have vm or pool")
        task = cls(config=config, job_id=job_id, os=os)
        task.save()
        return task

    def save_exclude(self) -> Optional[MappingIntStrAny]:
        """Fields omitted when persisting: heartbeats are transient."""
        return {"heartbeats": ...}

    def is_ready(self) -> bool:
        """Return True once every prerequisite task is running.

        Side effect: if a prereq resolved to an Error, this task is moved
        to `stopping` and saved before returning False.
        """
        if self.config.prereq_tasks:
            for prereq_id in self.config.prereq_tasks:
                prereq = Task.get_by_task_id(prereq_id)
                if isinstance(prereq, Error):
                    logging.info("task prereq has error: %s - %s", self.task_id, prereq)
                    self.error = prereq
                    self.state = TaskState.stopping
                    self.save()
                    return False
                if prereq.state != TaskState.running:
                    logging.info(
                        "task is waiting on prereq: %s - %s:",
                        self.task_id,
                        prereq.task_id,
                    )
                    return False

        return True

    # At current, the telemetry filter will generate something like this:
    #
    # {
    #     'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
    #     'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
    #     'state': 'stopped',
    #     'config': {
    #         'task': {'type': 'libfuzzer_coverage'},
    #         'vm': {'count': 1}
    #     }
    # }
    def telemetry_include(self) -> Optional[MappingIntStrAny]:
        """Fields included in telemetry (no user-identifying config)."""
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "config": {"vm": {"count": ...}, "task": {"type": ...}},
        }

    def event_include(self) -> Optional[MappingIntStrAny]:
        """Fields included in task events."""
        return {
            "job_id": ...,
            "task_id": ...,
            "state": ...,
            "error": ...,
        }

    def init(self) -> None:
        """Create the task's work queue, then move to `waiting`."""
        create_queue(self.task_id, account_id=get_fuzz_storage())
        self.state = TaskState.waiting
        self.save()

    def stopping(self) -> None:
        """Tear down the task: remove proxy forwards, delete the work
        queue, tell nodes to stop, and move to `stopped`."""
        # TODO: we need to tell every node currently working on this task to stop
        # TODO: we need to 'unschedule' this task from the existing pools

        self.state = TaskState.stopping
        logging.info("stopping task: %s:%s", self.job_id, self.task_id)
        ProxyForward.remove_forward(self.task_id)
        # NOTE(review): init() passes the UUID itself to create_queue but
        # this passes str(...) — confirm both resolve to the same queue name
        delete_queue(str(self.task_id), account_id=get_fuzz_storage())
        Node.stop_task(self.task_id)
        self.state = TaskState.stopped
        self.save()

    def queue_stop(self) -> None:
        """Queue the teardown to run asynchronously via the ORM queue."""
        self.queue(method=self.stopping)

    @classmethod
    def search_states(
        cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
    ) -> List["Task"]:
        """Search tasks, optionally filtered by job and/or states."""
        query: QueryFilter = {}
        if job_id:
            query["job_id"] = [job_id]
        if states:
            query["state"] = states

        return cls.search(query=query)

    @classmethod
    def search_expired(cls) -> List["Task"]:
        """Return still-available tasks whose end_time has passed."""
        time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
        return cls.search(
            query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
        )

    @classmethod
    def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
        """Look up a task by id, returning an Error on miss or ambiguity."""
        tasks = cls.search(query={"task_id": [task_id]})
        if not tasks:
            return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])

        if len(tasks) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
            )
        task = tasks[0]
        return task

    def get_pool(self) -> Optional[Pool]:
        """Resolve the pool this task should run on.

        Pool-configured tasks look the pool up by name; vm-configured
        tasks fall back to the first scaleset matching the vm sku/image.
        Returns None (after logging) when nothing suitable is found.
        """
        if self.config.pool:
            pool = Pool.get_by_name(self.config.pool.pool_name)
            if isinstance(pool, Error):
                logging.info(
                    "unable to schedule task to pool: %s - %s", self.task_id, pool
                )
                return None
            return pool
        elif self.config.vm:
            scalesets = Scaleset.search()
            scalesets = [
                x
                for x in scalesets
                if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
            ]
            for scaleset in scalesets:
                pool = Pool.get_by_name(scaleset.pool_name)
                if isinstance(pool, Error):
                    logging.info(
                        "unable to schedule task to pool: %s - %s",
                        self.task_id,
                        pool,
                    )
                else:
                    return pool

        # logging.warn is a deprecated alias; use logging.warning
        # (consistent with get_repro_vm_config below)
        logging.warning(
            "unable to find a scaleset that matches the task prereqs: %s",
            self.task_id,
        )
        return None

    def get_repro_vm_config(self) -> Union[TaskVm, None]:
        """Build a TaskVm for reproducing this task's crashes.

        Uses the task's own vm config when present; otherwise derives
        region/sku/image from the first scaleset of the task's pool.
        Returns None (after logging) if the pool or scalesets are missing.
        """
        if self.config.vm:
            return self.config.vm

        if self.config.pool is None:
            raise Exception("either pool or vm must be specified: %s" % self.task_id)

        pool = Pool.get_by_name(self.config.pool.pool_name)
        if isinstance(pool, Error):
            logging.info("unable to find pool from task: %s", self.task_id)
            return None

        for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
            return TaskVm(
                region=scaleset.region,
                sku=scaleset.vm_sku,
                image=scaleset.image,
            )

        logging.warning(
            "no scalesets are defined for task: %s:%s", self.job_id, self.task_id
        )
        return None

    def on_start(self) -> None:
        """Record the task deadline on first start and notify the job.

        Kept effectively idempotent: end_time is only set once.
        """
        if self.end_time is None:
            self.end_time = datetime.utcnow() + timedelta(
                hours=self.config.task.duration
            )
            self.save()

            # local import to avoid a circular dependency with ..jobs
            from ..jobs import Job

            job = Job.get(self.job_id)
            if job:
                job.on_start()

    @classmethod
    def key_fields(cls) -> Tuple[str, str]:
        """ORM partition/row key: (job_id, task_id)."""
        return ("job_id", "task_id")
|
Reference in New Issue
Block a user