mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
Event based webhooks (#296)
This commit is contained in:
@ -67,8 +67,7 @@ class Job(BASE_JOB, ORMMixin):
|
||||
|
||||
if not_stopped:
|
||||
for task in not_stopped:
|
||||
task.state = TaskState.stopping
|
||||
task.save()
|
||||
task.mark_stopping()
|
||||
else:
|
||||
self.state = JobState.stopped
|
||||
self.save()
|
||||
|
@ -251,16 +251,14 @@ class ORMMixin(ModelMixin):
|
||||
def event_include(self) -> Optional[MappingIntStrAny]:
|
||||
return {}
|
||||
|
||||
def event(self) -> Any:
|
||||
return self.raw(exclude_none=True, include=self.event_include())
|
||||
|
||||
def telemetry(self) -> Any:
|
||||
return self.raw(exclude_none=True, include=self.telemetry_include())
|
||||
|
||||
def _event_as_needed(self) -> None:
|
||||
# Upon ORM save, if the object returns event data, we'll send it to the
|
||||
# dashboard event subsystem
|
||||
data = self.event()
|
||||
|
||||
data = self.raw(exclude_none=True, include=self.event_include())
|
||||
if not data:
|
||||
return
|
||||
add_event(self.table_name(), data)
|
||||
@ -370,7 +368,11 @@ class ORMMixin(ModelMixin):
|
||||
annotation = inspect.signature(cls).parameters[key].annotation
|
||||
|
||||
if inspect.isclass(annotation):
|
||||
if issubclass(annotation, BaseModel) or issubclass(annotation, dict):
|
||||
if (
|
||||
issubclass(annotation, BaseModel)
|
||||
or issubclass(annotation, dict)
|
||||
or issubclass(annotation, list)
|
||||
):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
@ -381,9 +383,9 @@ class ORMMixin(ModelMixin):
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` annotation is a class
|
||||
# according to `inspect.isclass`.
|
||||
if getattr(annotation, "__origin__", None) == dict:
|
||||
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` and `List[_]` annotations
|
||||
# are a class according to `inspect.isclass`.
|
||||
if getattr(annotation, "__origin__", None) in [dict, list]:
|
||||
data[key] = json.loads(data[key])
|
||||
continue
|
||||
|
||||
|
@ -12,6 +12,11 @@ from onefuzztypes.enums import ErrorCode, TaskState
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Task as BASE_TASK
|
||||
from onefuzztypes.models import TaskConfig, TaskVm
|
||||
from onefuzztypes.webhooks import (
|
||||
WebhookEventTaskCreated,
|
||||
WebhookEventTaskFailed,
|
||||
WebhookEventTaskStopped,
|
||||
)
|
||||
|
||||
from ..azure.creds import get_fuzz_storage
|
||||
from ..azure.image import get_os
|
||||
@ -19,6 +24,7 @@ from ..azure.queue import create_queue, delete_queue
|
||||
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
|
||||
from ..pools import Node, Pool, Scaleset
|
||||
from ..proxy_forward import ProxyForward
|
||||
from ..webhooks import Webhook
|
||||
|
||||
|
||||
class Task(BASE_TASK, ORMMixin):
|
||||
@ -28,9 +34,7 @@ class Task(BASE_TASK, ORMMixin):
|
||||
task = Task.get_by_task_id(task_id)
|
||||
# if a prereq task fails, then mark this task as failed
|
||||
if isinstance(task, Error):
|
||||
self.error = task
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
self.mark_failed(task)
|
||||
return False
|
||||
|
||||
if task.state not in task.state.has_started():
|
||||
@ -50,6 +54,11 @@ class Task(BASE_TASK, ORMMixin):
|
||||
raise Exception("task must have vm or pool")
|
||||
task = cls(config=config, job_id=job_id, os=os)
|
||||
task.save()
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskCreated(
|
||||
job_id=task.job_id, task_id=task.task_id, config=config
|
||||
)
|
||||
)
|
||||
return task
|
||||
|
||||
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
||||
@ -61,9 +70,7 @@ class Task(BASE_TASK, ORMMixin):
|
||||
prereq = Task.get_by_task_id(prereq_id)
|
||||
if isinstance(prereq, Error):
|
||||
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
|
||||
self.error = prereq
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
self.mark_failed(prereq)
|
||||
return False
|
||||
if prereq.state != TaskState.running:
|
||||
logging.info(
|
||||
@ -110,7 +117,6 @@ class Task(BASE_TASK, ORMMixin):
|
||||
def stopping(self) -> None:
|
||||
# TODO: we need to 'unschedule' this task from the existing pools
|
||||
|
||||
self.state = TaskState.stopping
|
||||
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
|
||||
ProxyForward.remove_forward(self.task_id)
|
||||
delete_queue(str(self.task_id), account_id=get_fuzz_storage())
|
||||
@ -168,9 +174,17 @@ class Task(BASE_TASK, ORMMixin):
|
||||
return pool_tasks
|
||||
|
||||
def mark_stopping(self) -> None:
|
||||
if self.state not in [TaskState.stopped, TaskState.stopping]:
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
if self.state in [TaskState.stopped, TaskState.stopping]:
|
||||
logging.debug(
|
||||
"ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
|
||||
)
|
||||
return
|
||||
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id)
|
||||
)
|
||||
|
||||
def mark_failed(self, error: Error) -> None:
|
||||
if self.state in [TaskState.stopped, TaskState.stopping]:
|
||||
@ -183,6 +197,12 @@ class Task(BASE_TASK, ORMMixin):
|
||||
self.state = TaskState.stopping
|
||||
self.save()
|
||||
|
||||
Webhook.send_event(
|
||||
WebhookEventTaskFailed(
|
||||
job_id=self.job_id, task_id=self.task_id, error=error
|
||||
)
|
||||
)
|
||||
|
||||
def get_pool(self) -> Optional[Pool]:
|
||||
if self.config.pool:
|
||||
pool = Pool.get_by_name(self.config.pool.pool_name)
|
||||
|
245
src/api-service/__app__/onefuzzlib/webhooks.py
Normal file
245
src/api-service/__app__/onefuzzlib/webhooks.py
Normal file
@ -0,0 +1,245 @@
|
||||
import datetime
|
||||
import hmac
|
||||
import logging
|
||||
from hashlib import sha512
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ErrorCode, WebhookEventType, WebhookMessageState
|
||||
from onefuzztypes.models import Error, Result
|
||||
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
|
||||
from onefuzztypes.webhooks import (
|
||||
WebhookEvent,
|
||||
WebhookEventPing,
|
||||
WebhookEventTaskCreated,
|
||||
WebhookEventTaskFailed,
|
||||
WebhookEventTaskStopped,
|
||||
WebhookMessage,
|
||||
)
|
||||
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .__version__ import __version__
|
||||
from .azure.creds import get_func_storage
|
||||
from .azure.queue import queue_object
|
||||
from .orm import ORMMixin
|
||||
|
||||
# Upper bound on delivery attempts for one webhook message; compared against
# `try_count` when deciding whether to re-queue or mark the message failed.
MAX_TRIES = 5

# Age, in days, past which logged webhook messages are matched by
# `WebhookMessageLog.search_expired`.
EXPIRE_DAYS = 7

# Value sent in the `User-Agent` header of every outgoing webhook request.
USER_AGENT = "onefuzz-webhook %s" % __version__
|
||||
|
||||
|
||||
class WebhookMessageQueueObj(BaseModel):
    """Queue payload that identifies a single logged webhook message.

    The pair (`webhook_id`, `event_id`) is what `WebhookMessageLog` uses as
    its key fields, so it is enough to load the message again when the queue
    entry is processed.
    """

    # Identifier of the webhook the message is addressed to.
    webhook_id: UUID
    # Identifier of the event being delivered.
    event_id: UUID
|
||||
|
||||
|
||||
class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
    """Persisted record of one webhook delivery, with its retry state."""

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return the names of the fields that key a message log entry."""
        return ("webhook_id", "event_id")

    @classmethod
    def search_expired(cls) -> List["WebhookMessageLog"]:
        """Return the entries whose timestamp is older than `EXPIRE_DAYS` days."""
        max_age = datetime.timedelta(days=EXPIRE_DAYS)
        cutoff = datetime.datetime.utcnow() - max_age
        return cls.search(
            raw_unchecked_filter="Timestamp lt datetime'%s'" % cutoff.isoformat()
        )

    @classmethod
    def process_from_queue(cls, obj: WebhookMessageQueueObj) -> None:
        """Load the message referenced by a queue entry and process it.

        A queue entry whose message cannot be found is logged and dropped.
        """
        entry = cls.get(obj.webhook_id, obj.event_id)
        if entry is not None:
            entry.process()
            return

        logging.error("webhook message missing. %s:%s", obj.webhook_id, obj.event_id)

    def process(self) -> None:
        """Attempt one delivery and record the resulting state.

        Messages already in a terminal state (failed or succeeded) are left
        alone.  Otherwise the attempt counter is incremented, the message is
        sent, and the entry is saved as succeeded, retrying (and re-queued),
        or failed once `MAX_TRIES` attempts have been used.
        """
        terminal_states = (WebhookMessageState.failed, WebhookMessageState.succeeded)
        if self.state in terminal_states:
            logging.info(
                "webhook message already handled: %s:%s", self.webhook_id, self.event_id
            )
            return

        self.try_count += 1

        logging.debug("sending webhook: %s:%s", self.webhook_id, self.event_id)
        delivered = self.send()
        if delivered:
            self.state = WebhookMessageState.succeeded
            self.save()
            logging.info("sent webhook event: %s:%s", self.webhook_id, self.event_id)
            return

        if self.try_count >= MAX_TRIES:
            # Out of attempts: record the failure and stop re-queuing.
            self.state = WebhookMessageState.failed
            self.save()
            logging.warning(
                "sending webhook event failed %d times. %s:%s",
                self.try_count,
                self.webhook_id,
                self.event_id,
            )
            return

        # The state must be saved as `retrying` before queuing, because
        # `queue_webhook` picks the visibility timeout from the state.
        self.state = WebhookMessageState.retrying
        self.save()
        self.queue_webhook()
        logging.warning(
            "sending webhook event failed, re-queued. %s:%s",
            self.webhook_id,
            self.event_id,
        )

    def send(self) -> bool:
        """Deliver this message through its webhook.

        Returns:
            The result of `Webhook.send`, or False when the webhook cannot be
            looked up or sending raises an exception.
        """
        target = Webhook.get_by_id(self.webhook_id)
        if isinstance(target, Error):
            logging.error(
                "webhook no longer exists: %s:%s", self.webhook_id, self.event_id
            )
            return False

        # Deliberately broad: any failure while sending counts as an
        # unsuccessful attempt so the caller can apply its retry policy.
        try:
            outcome = target.send(self)
        except Exception as err:
            logging.error(
                "webhook failed with exception: %s:%s - %s",
                self.webhook_id,
                self.event_id,
                err,
            )
            return False
        else:
            return outcome

    def queue_webhook(self) -> None:
        """Put a reference to this message on the "webhooks" queue.

        Queued messages are enqueued with a visibility timeout of 0 and
        retrying messages with 30; any other state is logged as an error and
        nothing is enqueued.
        """
        payload = WebhookMessageQueueObj(
            webhook_id=self.webhook_id, event_id=self.event_id
        )

        state = self.state
        if state != WebhookMessageState.queued and state != WebhookMessageState.retrying:
            logging.error(
                "invalid WebhookMessage queue state, not queuing. %s:%s - %s",
                self.webhook_id,
                self.event_id,
                self.state,
            )
            return

        delay = 0 if state == WebhookMessageState.queued else 30
        queue_object(
            "webhooks",
            payload,
            visibility_timeout=delay,
            account_id=get_func_storage(),
        )
|
||||
|
||||
|
||||
def get_event_type(event: WebhookEvent) -> WebhookEventType:
    """Map a webhook event instance to its `WebhookEventType`.

    The event is checked against each supported event class in order and the
    type of the first match is returned.

    Raises:
        NotImplementedError: if the event is not an instance of any supported
            event class.
    """
    supported = (
        (WebhookEventTaskCreated, WebhookEventType.task_created),
        (WebhookEventTaskFailed, WebhookEventType.task_failed),
        (WebhookEventTaskStopped, WebhookEventType.task_stopped),
        (WebhookEventPing, WebhookEventType.ping),
    )

    for event_class, event_type in supported:
        if isinstance(event, event_class):
            return event_type

    raise NotImplementedError("unsupported event type: %s" % event)
|
||||
|
||||
|
||||
class Webhook(BASE_WEBHOOK, ORMMixin):
    """A registered webhook endpoint and the operations to deliver events to it."""

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return the names of the fields that key a webhook."""
        return ("webhook_id", "name")

    @classmethod
    def send_event(cls, event: WebhookEvent) -> None:
        """Record and queue `event` for every webhook subscribed to its type.

        Raises:
            NotImplementedError: propagated from `get_event_type` when the
                event class is not supported.
        """
        event_type = get_event_type(event)
        for webhook in get_webhooks_cached():
            if event_type not in webhook.event_types:
                continue

            webhook._add_event(event_type, event)

    @classmethod
    def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
        """Look up a webhook by its id.

        Returns:
            The webhook, or an `Error` with code `INVALID_REQUEST` when the
            search yields no match or more than one match.
        """
        webhooks = cls.search(query={"webhook_id": [webhook_id]})
        if not webhooks:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["unable to find webhook"]
            )

        if len(webhooks) != 1:
            return Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["error identifying webhook"],
            )
        return webhooks[0]

    def _add_event(self, event_type: WebhookEventType, event: WebhookEvent) -> None:
        """Save a message log entry for `event` and queue it for delivery."""
        message = WebhookMessageLog(
            webhook_id=self.webhook_id,
            event_type=event_type,
            event=event,
        )
        message.save()
        message.queue_webhook()

    def ping(self) -> WebhookEventPing:
        """Queue a ping event for this webhook and return the event."""
        ping = WebhookEventPing()
        self._add_event(WebhookEventType.ping, ping)
        return ping

    def send(self, message_log: WebhookMessageLog, *, timeout: float = 30.0) -> bool:
        """POST a logged message to this webhook's URL.

        The request body is the serialized message built by `build_message`.
        When the webhook has a secret token, the message digest is sent in
        the `X-Onefuzz-Digest` header.

        Args:
            message_log: the logged message to deliver.
            timeout: seconds to wait for the endpoint before giving up, so an
                unresponsive endpoint cannot block the caller indefinitely.

        Returns:
            True when the endpoint answered with a successful status
            (`response.ok`), False otherwise.

        Raises:
            Exception: if the webhook has no URL.  Errors raised by
                `requests.post` (including timeouts) propagate to the caller.
        """
        if self.url is None:
            raise Exception("webhook URL incorrectly removed: %s" % self.webhook_id)

        data, digest = build_message(
            webhook_id=self.webhook_id,
            event_id=message_log.event_id,
            event_type=message_log.event_type,
            event=message_log.event,
            secret_token=self.secret_token,
        )

        headers = {"Content-type": "application/json", "User-Agent": USER_AGENT}

        if digest:
            headers["X-Onefuzz-Digest"] = digest

        response = requests.post(
            self.url,
            data=data,
            headers=headers,
            timeout=timeout,
        )
        return response.ok
|
||||
|
||||
|
||||
def build_message(
    *,
    webhook_id: UUID,
    event_id: UUID,
    event_type: WebhookEventType,
    event: WebhookEvent,
    secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
    """Serialize a webhook message and optionally sign it.

    The message is rendered as JSON with sorted keys and encoded to bytes.
    When `secret_token` is set, the second tuple element is the hex
    HMAC-SHA512 digest of those bytes keyed with the encoded token;
    otherwise it is None.

    Returns:
        A `(payload, digest)` tuple.
    """
    message = WebhookMessage(
        webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
    )
    payload = message.json(sort_keys=True).encode()

    if not secret_token:
        return (payload, None)

    signature = hmac.new(secret_token.encode(), msg=payload, digestmod=sha512)
    return (payload, signature.hexdigest())
|
||||
|
||||
|
||||
@cached(ttl=30)
def get_webhooks_cached() -> List[Webhook]:
    """Return the result of `Webhook.search()`, memoized with a TTL of 30.

    Because of the cache, a newly added or changed webhook may not be seen
    by callers until the cached result expires.
    """
    webhooks = Webhook.search()
    return webhooks
|
Reference in New Issue
Block a user