onefuzz/src/api-service/__app__/onefuzzlib/tasks/main.py
bmc-msft beea318968 Add User Info to created tasks (#303)
This PR makes user information from JWT tokens available as part of a Task.

Included changes:
* Renames `verify_token` to `call_if_agent`, since this function is specific to agent token verification
* Renames `is_authorized` to `is_agent`, since this function checks whether the token belongs to an agent
* Adds support for unmanaged nodes in `is_agent` (see #133 for more information)
* Saves the user information from the JWT token on task create as part of `TaskConfig` (see the sketch after this list)
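
As a rough illustration of what gets captured (a minimal sketch: the `user_info_from_claims` helper and the `claims` dict are hypothetical, and the `application_id`/`object_id`/`upn` fields on `UserInfo` are assumed from this PR's usage):

```python
from typing import Any, Dict
from uuid import UUID

from onefuzztypes.models import UserInfo


def user_info_from_claims(claims: Dict[str, Any]) -> UserInfo:
    # hypothetical helper: `oid` is the caller's AAD object id, `appid`
    # identifies the client application, and `upn` is usually (but not
    # always) an email address
    return UserInfo(
        application_id=UUID(claims["appid"]) if "appid" in claims else None,
        object_id=UUID(claims["oid"]) if "oid" in claims else None,
        upn=claims.get("upn"),
    )
```

The resulting `UserInfo` is then passed to `Task.create` (see the file below) and stored with the task.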

Note: `TaskConfig` is what is provided to notification templates. This enables GitHub issues and ADO work items to tie back to the user that created the task.

Note: while `upn` _usually_ contains an email address for AAD user tokens, this is not guaranteed. If we were going to make use of the email address, we should perform a Graph lookup based on the `oid`, but we're not.
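
For context only, such a lookup might resemble the sketch below (hypothetical code, not part of this PR; it assumes a Microsoft Graph-scoped bearer token is available):

```python
from typing import Optional

import requests


def lookup_email(oid: str, graph_token: str) -> Optional[str]:
    # resolve the user's primary SMTP address from Microsoft Graph;
    # `upn` alone is not guaranteed to be a deliverable email address
    resp = requests.get(
        f"https://graph.microsoft.com/v1.0/users/{oid}",
        headers={"Authorization": f"Bearer {graph_token}"},
        params={"$select": "mail"},
    )
    resp.raise_for_status()
    mail: Optional[str] = resp.json().get("mail")
    return mail
```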
2020-11-13 11:50:52 +00:00


#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
)
from ..azure.creds import get_fuzz_storage
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward
from ..webhooks import Webhook
class Task(BASE_TASK, ORMMixin):
def ready_to_schedule(self) -> bool:
if self.config.prereq_tasks:
for task_id in self.config.prereq_tasks:
task = Task.get_by_task_id(task_id)
# if a prereq task fails, then mark this task as failed
if isinstance(task, Error):
self.mark_failed(task)
return False
                if task.state not in TaskState.has_started():
return False
return True
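    # create a new task, recording which user requested it, and announce
    # it via a 'task created' webhook event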
@classmethod
def create(
cls, config: TaskConfig, job_id: UUID, user_info: UserInfo
) -> Union["Task", Error]:
if config.vm:
os = get_os(config.vm.region, config.vm.image)
elif config.pool:
pool = Pool.get_by_name(config.pool.pool_name)
if isinstance(pool, Error):
return pool
os = pool.os
else:
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
job_id=task.job_id,
task_id=task.task_id,
config=config,
user_info=user_info,
)
)
return task
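    # exclude the high-churn heartbeats field when persisting the task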
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"heartbeats": ...}
def is_ready(self) -> bool:
if self.config.prereq_tasks:
for prereq_id in self.config.prereq_tasks:
prereq = Task.get_by_task_id(prereq_id)
if isinstance(prereq, Error):
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
self.mark_failed(prereq)
return False
if prereq.state != TaskState.running:
logging.info(
"task is waiting on prereq: %s - %s:",
self.task_id,
prereq.task_id,
)
return False
return True
    # Currently, the telemetry filter will generate something like this:
#
# {
# 'job_id': 'f4a20fd8-0dcc-4a4e-8804-6ef7df50c978',
# 'task_id': '835f7b3f-43ad-4718-b7e4-d506d9667b09',
# 'state': 'stopped',
# 'config': {
# 'task': {'type': 'libfuzzer_coverage'},
# 'vm': {'count': 1}
# }
# }
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"task_id": ...,
"state": ...,
"config": {"vm": {"count": ...}, "task": {"type": ...}},
}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"task_id": ...,
"state": ...,
"error": ...,
}
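    # handle the 'init' state: create the queue used to feed this task,
    # then mark the task as waiting to be scheduled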
def init(self) -> None:
create_queue(self.task_id, account_id=get_fuzz_storage())
self.state = TaskState.waiting
self.save()
def stopping(self) -> None:
# TODO: we need to 'unschedule' this task from the existing pools
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), account_id=get_fuzz_storage())
Node.stop_task(self.task_id)
self.state = TaskState.stopped
self.save()
@classmethod
def search_states(
cls, *, job_id: Optional[UUID] = None, states: Optional[List[TaskState]] = None
) -> List["Task"]:
query: QueryFilter = {}
if job_id:
query["job_id"] = [job_id]
if states:
query["state"] = states
return cls.search(query=query)
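    # find tasks that have outlived their end_time but are not yet stopped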
@classmethod
def search_expired(cls) -> List["Task"]:
time_filter = "end_time lt datetime'%s'" % datetime.utcnow().isoformat()
return cls.search(
query={"state": TaskState.available()}, raw_unchecked_filter=time_filter
)
@classmethod
def get_by_task_id(cls, task_id: UUID) -> Union[Error, "Task"]:
tasks = cls.search(query={"task_id": [task_id]})
if not tasks:
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unable to find task"])
if len(tasks) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["error identifying task"]
)
task = tasks[0]
return task
@classmethod
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]:
tasks = cls.search_states(states=TaskState.available())
if not tasks:
return []
pool_tasks = []
for task in tasks:
task_pool = task.get_pool()
if not task_pool:
continue
if pool_name == task_pool.name:
pool_tasks.append(task)
return pool_tasks
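    # request shutdown; this is a no-op if the task is already on its way down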
def mark_stopping(self) -> None:
if self.state in [TaskState.stopped, TaskState.stopping]:
logging.debug(
"ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
)
return
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
)
def mark_failed(self, error: Error) -> None:
if self.state in [TaskState.stopped, TaskState.stopping]:
logging.debug(
"ignoring post-task stop failures for %s:%s", self.job_id, self.task_id
)
return
self.error = error
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskFailed(
job_id=self.job_id,
task_id=self.task_id,
error=error,
user_info=self.user_info,
)
)
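    # resolve the pool this task runs on: directly by name if configured,
    # otherwise via a scaleset matching the requested VM sku and image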
def get_pool(self) -> Optional[Pool]:
if self.config.pool:
pool = Pool.get_by_name(self.config.pool.pool_name)
if isinstance(pool, Error):
logging.info(
"unable to schedule task to pool: %s - %s", self.task_id, pool
)
return None
return pool
elif self.config.vm:
scalesets = Scaleset.search()
scalesets = [
x
for x in scalesets
if x.vm_sku == self.config.vm.sku and x.image == self.config.vm.image
]
for scaleset in scalesets:
pool = Pool.get_by_name(scaleset.pool_name)
if isinstance(pool, Error):
logging.info(
"unable to schedule task to pool: %s - %s",
self.task_id,
pool,
)
else:
return pool
logging.warning(
"unable to find a scaleset that matches the task prereqs: %s",
self.task_id,
)
return None
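    # choose a VM config for crash reproduction: prefer the task's explicit
    # VM config, otherwise borrow region/sku/image from the pool's scalesets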
    def get_repro_vm_config(self) -> Optional[TaskVm]:
if self.config.vm:
return self.config.vm
if self.config.pool is None:
raise Exception("either pool or vm must be specified: %s" % self.task_id)
pool = Pool.get_by_name(self.config.pool.pool_name)
if isinstance(pool, Error):
logging.info("unable to find pool from task: %s", self.task_id)
return None
for scaleset in Scaleset.search_by_pool(self.config.pool.pool_name):
return TaskVm(
region=scaleset.region,
sku=scaleset.vm_sku,
image=scaleset.image,
)
logging.warning(
"no scalesets are defined for task: %s:%s", self.job_id, self.task_id
)
return None
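    # on first start, pin the task's deadline from the configured duration
    # and notify the owning job (Job is imported locally, presumably to
    # avoid a circular import)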
def on_start(self) -> None:
# try to keep this effectively idempotent
if self.end_time is None:
self.end_time = datetime.utcnow() + timedelta(
hours=self.config.task.duration
)
self.save()
from ..jobs import Job
job = Job.get(self.job_id)
if job:
job.on_start()
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("job_id", "task_id")