Event based webhooks (#296)

This commit is contained in:
bmc-msft
2020-11-12 17:44:42 -05:00
committed by GitHub
parent 693c988854
commit 31f099d3d4
24 changed files with 2133 additions and 35 deletions

View File

@ -67,8 +67,7 @@ class Job(BASE_JOB, ORMMixin):
if not_stopped:
for task in not_stopped:
task.state = TaskState.stopping
task.save()
task.mark_stopping()
else:
self.state = JobState.stopped
self.save()

View File

@ -251,16 +251,14 @@ class ORMMixin(ModelMixin):
def event_include(self) -> Optional[MappingIntStrAny]:
return {}
def event(self) -> Any:
return self.raw(exclude_none=True, include=self.event_include())
def telemetry(self) -> Any:
return self.raw(exclude_none=True, include=self.telemetry_include())
def _event_as_needed(self) -> None:
# Upon ORM save, if the object returns event data, we'll send it to the
# dashboard event subsystem
data = self.event()
data = self.raw(exclude_none=True, include=self.event_include())
if not data:
return
add_event(self.table_name(), data)
@ -370,7 +368,11 @@ class ORMMixin(ModelMixin):
annotation = inspect.signature(cls).parameters[key].annotation
if inspect.isclass(annotation):
if issubclass(annotation, BaseModel) or issubclass(annotation, dict):
if (
issubclass(annotation, BaseModel)
or issubclass(annotation, dict)
or issubclass(annotation, list)
):
data[key] = json.loads(data[key])
continue
@ -381,9 +383,9 @@ class ORMMixin(ModelMixin):
data[key] = json.loads(data[key])
continue
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` annotation is a class
# according to `inspect.isclass`.
if getattr(annotation, "__origin__", None) == dict:
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` and `List[_]` annotations
# are a class according to `inspect.isclass`.
if getattr(annotation, "__origin__", None) in [dict, list]:
data[key] = json.loads(data[key])
continue

View File

@ -12,6 +12,11 @@ from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
)
from ..azure.creds import get_fuzz_storage
from ..azure.image import get_os
@ -19,6 +24,7 @@ from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward
from ..webhooks import Webhook
class Task(BASE_TASK, ORMMixin):
@ -28,9 +34,7 @@ class Task(BASE_TASK, ORMMixin):
task = Task.get_by_task_id(task_id)
# if a prereq task fails, then mark this task as failed
if isinstance(task, Error):
self.error = task
self.state = TaskState.stopping
self.save()
self.mark_failed(task)
return False
if task.state not in task.state.has_started():
@ -50,6 +54,11 @@ class Task(BASE_TASK, ORMMixin):
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
job_id=task.job_id, task_id=task.task_id, config=config
)
)
return task
def save_exclude(self) -> Optional[MappingIntStrAny]:
@ -61,9 +70,7 @@ class Task(BASE_TASK, ORMMixin):
prereq = Task.get_by_task_id(prereq_id)
if isinstance(prereq, Error):
logging.info("task prereq has error: %s - %s", self.task_id, prereq)
self.error = prereq
self.state = TaskState.stopping
self.save()
self.mark_failed(prereq)
return False
if prereq.state != TaskState.running:
logging.info(
@ -110,7 +117,6 @@ class Task(BASE_TASK, ORMMixin):
def stopping(self) -> None:
# TODO: we need to 'unschedule' this task from the existing pools
self.state = TaskState.stopping
logging.info("stopping task: %s:%s", self.job_id, self.task_id)
ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), account_id=get_fuzz_storage())
@ -168,9 +174,17 @@ class Task(BASE_TASK, ORMMixin):
return pool_tasks
def mark_stopping(self) -> None:
if self.state not in [TaskState.stopped, TaskState.stopping]:
self.state = TaskState.stopping
self.save()
if self.state in [TaskState.stopped, TaskState.stopping]:
logging.debug(
"ignoring post-task stop calls to stop %s:%s", self.job_id, self.task_id
)
return
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(job_id=self.job_id, task_id=self.task_id)
)
def mark_failed(self, error: Error) -> None:
if self.state in [TaskState.stopped, TaskState.stopping]:
@ -183,6 +197,12 @@ class Task(BASE_TASK, ORMMixin):
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskFailed(
job_id=self.job_id, task_id=self.task_id, error=error
)
)
def get_pool(self) -> Optional[Pool]:
if self.config.pool:
pool = Pool.get_by_name(self.config.pool.pool_name)

View File

@ -0,0 +1,245 @@
import datetime
import hmac
import logging
from hashlib import sha512
from typing import List, Optional, Tuple
from uuid import UUID
import requests
from memoization import cached
from onefuzztypes.enums import ErrorCode, WebhookEventType, WebhookMessageState
from onefuzztypes.models import Error, Result
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
from onefuzztypes.webhooks import (
WebhookEvent,
WebhookEventPing,
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
WebhookMessage,
)
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel
from .__version__ import __version__
from .azure.creds import get_func_storage
from .azure.queue import queue_object
from .orm import ORMMixin
# Number of delivery attempts after which WebhookMessageLog.process() marks a
# message as failed instead of re-queueing it.
MAX_TRIES = 5
# Age, in days, past which WebhookMessageLog.search_expired() reports a record.
EXPIRE_DAYS = 7
# Value sent in the User-Agent header of every outgoing webhook request.
USER_AGENT = "onefuzz-webhook %s" % (__version__)
class WebhookMessageQueueObj(BaseModel):
    """Payload placed on the "webhooks" queue to request one delivery attempt.

    The two ids are the same pair that WebhookMessageLog uses as its key
    fields, so a queue consumer can look the stored message back up.
    """

    # Id of the webhook the message should be delivered to.
    webhook_id: UUID
    # Id of the event (message log entry) to deliver.
    event_id: UUID
class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
    """Stored record of one webhook delivery, including retry bookkeeping."""

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return the names of the two fields that key this record."""
        return ("webhook_id", "event_id")

    @classmethod
    def search_expired(cls) -> List["WebhookMessageLog"]:
        """Return records whose Timestamp is more than EXPIRE_DAYS days old."""
        cutoff = datetime.datetime.utcnow() - datetime.timedelta(days=EXPIRE_DAYS)
        return cls.search(
            raw_unchecked_filter="Timestamp lt datetime'%s'" % cutoff.isoformat()
        )

    @classmethod
    def process_from_queue(cls, obj: WebhookMessageQueueObj) -> None:
        """Look up the message named by a queue entry and process it.

        A missing message is logged as an error and otherwise ignored.
        """
        entry = cls.get(obj.webhook_id, obj.event_id)
        if entry is not None:
            entry.process()
            return

        logging.error("webhook message missing. %s:%s", obj.webhook_id, obj.event_id)

    def process(self) -> None:
        """Attempt one delivery and record the resulting state.

        Messages already in the failed or succeeded state are left alone.
        Otherwise the try counter is bumped and a send is attempted; on
        failure the message is re-queued until MAX_TRIES attempts were made,
        after which it is marked failed.
        """
        finished = (WebhookMessageState.failed, WebhookMessageState.succeeded)
        if self.state in finished:
            logging.info(
                "webhook message already handled: %s:%s", self.webhook_id, self.event_id
            )
            return

        self.try_count += 1
        logging.debug("sending webhook: %s:%s", self.webhook_id, self.event_id)
        delivered = self.send()

        if delivered:
            self.state = WebhookMessageState.succeeded
            self.save()
            logging.info("sent webhook event: %s:%s", self.webhook_id, self.event_id)
            return

        if self.try_count >= MAX_TRIES:
            # Out of attempts: persist the terminal state, do not re-queue.
            self.state = WebhookMessageState.failed
            self.save()
            logging.warning(
                "sending webhook event failed %d times. %s:%s",
                self.try_count,
                self.webhook_id,
                self.event_id,
            )
            return

        # The state must be saved as `retrying` before queueing, because
        # queue_webhook() picks the delay from the current state.
        self.state = WebhookMessageState.retrying
        self.save()
        self.queue_webhook()
        logging.warning(
            "sending webhook event failed, re-queued. %s:%s",
            self.webhook_id,
            self.event_id,
        )

    def send(self) -> bool:
        """Deliver this message through its webhook.

        Returns the result of Webhook.send(), or False when the webhook
        cannot be found or the send raises.
        """
        target = Webhook.get_by_id(self.webhook_id)
        if isinstance(target, Error):
            logging.error(
                "webhook no longer exists: %s:%s", self.webhook_id, self.event_id
            )
            return False

        try:
            delivered = target.send(self)
        except Exception as err:
            # Any failure while sending counts as an unsuccessful attempt so
            # that process() can apply its retry logic.
            logging.error(
                "webhook failed with exception: %s:%s - %s",
                self.webhook_id,
                self.event_id,
                err,
            )
            return False
        return delivered

    def queue_webhook(self) -> None:
        """Put a delivery request for this message on the "webhooks" queue.

        Only messages in the queued or retrying state are queued; any other
        state is logged as an error and nothing is queued.
        """
        request = WebhookMessageQueueObj(
            webhook_id=self.webhook_id, event_id=self.event_id
        )

        # Visibility timeout per state: first attempts are visible at once,
        # retries are delayed (presumably seconds - confirm in queue_object).
        delay_by_state = {
            WebhookMessageState.queued: 0,
            WebhookMessageState.retrying: 30,
        }
        visibility_timeout = delay_by_state.get(self.state)
        if visibility_timeout is None:
            logging.error(
                "invalid WebhookMessage queue state, not queuing. %s:%s - %s",
                self.webhook_id,
                self.event_id,
                self.state,
            )
            return

        queue_object(
            "webhooks",
            request,
            visibility_timeout=visibility_timeout,
            account_id=get_func_storage(),
        )
def get_event_type(event: WebhookEvent) -> WebhookEventType:
    """Map a webhook event instance to its WebhookEventType.

    Raises:
        NotImplementedError: if the event is not an instance of any of the
            known event classes.
    """
    # Checked in this order; the first matching class wins.
    known_events = (
        (WebhookEventTaskCreated, WebhookEventType.task_created),
        (WebhookEventTaskFailed, WebhookEventType.task_failed),
        (WebhookEventTaskStopped, WebhookEventType.task_stopped),
        (WebhookEventPing, WebhookEventType.ping),
    )

    for event_class, event_type in known_events:
        if isinstance(event, event_class):
            return event_type

    raise NotImplementedError("unsupported event type: %s" % event)
class Webhook(BASE_WEBHOOK, ORMMixin):
    """A registered webhook endpoint and the logic to deliver events to it."""

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        """Return the names of the two fields that key this record."""
        return ("webhook_id", "name")

    @classmethod
    def send_event(cls, event: WebhookEvent) -> None:
        """Record and queue `event` for every webhook subscribed to its type.

        The webhook list comes from get_webhooks_cached(), so a webhook that
        was just created or changed may not be seen until the cache expires.

        Raises:
            NotImplementedError: if the event's type is not supported by
                get_event_type().
        """
        event_type = get_event_type(event)
        for webhook in get_webhooks_cached():
            if event_type in webhook.event_types:
                webhook._add_event(event_type, event)

    @classmethod
    def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
        """Return the webhook with the given id.

        Returns an Error (code INVALID_REQUEST) instead when no webhook, or
        more than one webhook, matches the id.
        """
        webhooks = cls.search(query={"webhook_id": [webhook_id]})
        if not webhooks:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["unable to find webhook"]
            )

        if len(webhooks) != 1:
            # The message previously said "Notification", which was misleading
            # for an error about webhooks.
            return Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["error identifying Webhook"],
            )

        return webhooks[0]

    def _add_event(self, event_type: WebhookEventType, event: WebhookEvent) -> None:
        """Persist a message log entry for `event` and queue it for delivery."""
        message = WebhookMessageLog(
            webhook_id=self.webhook_id,
            event_type=event_type,
            event=event,
        )
        # Save before queueing: the queue consumer looks the message up by id.
        message.save()
        message.queue_webhook()

    def ping(self) -> WebhookEventPing:
        """Queue a ping event for this webhook and return the event."""
        ping = WebhookEventPing()
        self._add_event(WebhookEventType.ping, ping)
        return ping

    def send(self, message_log: WebhookMessageLog) -> bool:
        """POST one logged message to this webhook's URL.

        The body is the JSON built by build_message(). When a secret token is
        configured, its HMAC digest is sent in the X-Onefuzz-Digest header.

        Returns:
            True when the HTTP response status indicates success
            (`response.ok`), False otherwise.

        Raises:
            Exception: if the webhook has no URL.
            requests.exceptions.RequestException: on connection errors or
                when the endpoint does not answer within the timeout.
        """
        if self.url is None:
            raise Exception("webhook URL incorrectly removed: %s" % self.webhook_id)

        data, digest = build_message(
            webhook_id=self.webhook_id,
            event_id=message_log.event_id,
            event_type=message_log.event_type,
            event=message_log.event,
            secret_token=self.secret_token,
        )

        headers = {"Content-type": "application/json", "User-Agent": USER_AGENT}
        if digest:
            headers["X-Onefuzz-Digest"] = digest

        response = requests.post(
            self.url,
            data=data,
            headers=headers,
            # `requests` never times out by default, so an unresponsive
            # endpoint would block the caller forever. With a timeout the
            # call raises instead, and the delivery is treated as a failed
            # attempt by the caller's exception handling.
            timeout=30,
        )
        return response.ok
def build_message(
    *,
    webhook_id: UUID,
    event_id: UUID,
    event_type: WebhookEventType,
    event: WebhookEvent,
    secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
    """Serialize a webhook message and optionally sign it.

    Returns:
        A `(data, digest)` tuple. `data` is the encoded JSON of a
        WebhookMessage with sorted keys. `digest` is the hex HMAC-SHA512 of
        `data` keyed with `secret_token`, or None when no (or an empty)
        token is given.
    """
    message = WebhookMessage(
        webhook_id=webhook_id, event_id=event_id, event_type=event_type, event=event
    )
    # Keys are sorted so the serialized bytes, and therefore the digest,
    # do not depend on field ordering.
    data = message.json(sort_keys=True).encode()

    if not secret_token:
        return (data, None)

    signature = hmac.new(secret_token.encode(), msg=data, digestmod=sha512)
    return (data, signature.hexdigest())
@cached(ttl=30)
def get_webhooks_cached() -> List[Webhook]:
    """Return all webhooks, memoized by `cached` with ttl=30.

    NOTE(review): the ttl is presumably in seconds, per the `memoization`
    package - confirm. Within that window, newly added or changed webhooks
    are not reflected in the result.
    """
    all_webhooks = Webhook.search()
    return all_webhooks