Files
onefuzz/src/api-service/__app__/onefuzzlib/webhooks.py

250 lines
7.6 KiB
Python

#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import datetime
import hmac
import logging
from hashlib import sha512
from typing import List, Optional, Tuple
from uuid import UUID, uuid4
import requests
from memoization import cached
from onefuzztypes.enums import ErrorCode, WebhookMessageState
from onefuzztypes.events import Event, EventMessage, EventPing, EventType
from onefuzztypes.models import Error, Result
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
from onefuzztypes.webhooks import WebhookMessage
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel, Field
from .__version__ import __version__
from .azure.creds import get_instance_id, get_instance_name
from .azure.queue import queue_object
from .azure.storage import StorageType
from .orm import MappingIntStrAny, ORMMixin
MAX_TRIES = 5
EXPIRE_DAYS = 7
USER_AGENT = "onefuzz-webhook %s" % (__version__)
class WebhookMessageQueueObj(BaseModel):
webhook_id: UUID
event_id: UUID
class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("webhook_id", "event_id")
def export_exclude(self) -> Optional[MappingIntStrAny]:
return {"etag": ..., "timestamp": ...}
@classmethod
def search_expired(cls) -> List["WebhookMessageLog"]:
expire_time = datetime.datetime.utcnow() - datetime.timedelta(days=EXPIRE_DAYS)
time_filter = "Timestamp lt datetime'%s'" % expire_time.isoformat()
return cls.search(raw_unchecked_filter=time_filter)
@classmethod
def process_from_queue(cls, obj: WebhookMessageQueueObj) -> None:
message = cls.get(obj.webhook_id, obj.event_id)
if message is None:
logging.error(
"webhook message missing. %s:%s", obj.webhook_id, obj.event_id
)
return
message.process()
def process(self) -> None:
if self.state in [WebhookMessageState.failed, WebhookMessageState.succeeded]:
logging.info(
"webhook message already handled: %s:%s", self.webhook_id, self.event_id
)
return
self.try_count += 1
logging.debug("sending webhook: %s:%s", self.webhook_id, self.event_id)
if self.send():
self.state = WebhookMessageState.succeeded
self.save()
logging.info("sent webhook event: %s:%s", self.webhook_id, self.event_id)
return
if self.try_count < MAX_TRIES:
self.state = WebhookMessageState.retrying
self.save()
self.queue_webhook()
logging.warning(
"sending webhook event failed, re-queued. %s:%s",
self.webhook_id,
self.event_id,
)
else:
self.state = WebhookMessageState.failed
self.save()
logging.warning(
"sending webhook event failed %d times. %s:%s",
self.try_count,
self.webhook_id,
self.event_id,
)
def send(self) -> bool:
webhook = Webhook.get_by_id(self.webhook_id)
if isinstance(webhook, Error):
logging.error(
"webhook no longer exists: %s:%s", self.webhook_id, self.event_id
)
return False
try:
return webhook.send(self)
except Exception as err:
logging.error(
"webhook failed with exception: %s:%s - %s",
self.webhook_id,
self.event_id,
err,
)
return False
def queue_webhook(self) -> None:
obj = WebhookMessageQueueObj(webhook_id=self.webhook_id, event_id=self.event_id)
if self.state == WebhookMessageState.queued:
visibility_timeout = 0
elif self.state == WebhookMessageState.retrying:
visibility_timeout = 30
else:
logging.error(
"invalid WebhookMessage queue state, not queuing. %s:%s - %s",
self.webhook_id,
self.event_id,
self.state,
)
return
queue_object(
"webhooks",
obj,
StorageType.config,
visibility_timeout=visibility_timeout,
)
class Webhook(BASE_WEBHOOK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("webhook_id", "name")
@classmethod
def send_event(cls, event_message: EventMessage) -> None:
for webhook in get_webhooks_cached():
if event_message.event_type not in webhook.event_types:
continue
webhook._add_event(event_message)
@classmethod
def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
webhooks = cls.search(query={"webhook_id": [webhook_id]})
if not webhooks:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["unable to find webhook"]
)
if len(webhooks) != 1:
return Error(
code=ErrorCode.INVALID_REQUEST,
errors=["error identifying Notification"],
)
webhook = webhooks[0]
return webhook
def _add_event(self, event_message: EventMessage) -> None:
message = WebhookMessageLog(
webhook_id=self.webhook_id,
event_id=event_message.event_id,
event_type=event_message.event_type,
event=event_message.event,
instance_id=event_message.instance_id,
instance_name=event_message.instance_name,
)
message.save()
message.queue_webhook()
def ping(self) -> EventPing:
ping = EventPing(ping_id=uuid4())
self._add_event(
EventMessage(
event_type=EventType.ping,
event=ping,
instance_id=get_instance_id(),
instance_name=get_instance_name(),
)
)
return ping
def send(self, message_log: WebhookMessageLog) -> bool:
if self.url is None:
raise Exception("webhook URL incorrectly removed: %s" % self.webhook_id)
data, digest = build_message(
webhook_id=self.webhook_id,
event_id=message_log.event_id,
event_type=message_log.event_type,
event=message_log.event,
secret_token=self.secret_token,
)
headers = {"Content-type": "application/json", "User-Agent": USER_AGENT}
if digest:
headers["X-Onefuzz-Digest"] = digest
response = requests.post(
self.url,
data=data,
headers=headers,
)
return response.ok
def build_message(
*,
webhook_id: UUID,
event_id: UUID,
event_type: EventType,
event: Event,
secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
data = (
WebhookMessage(
webhook_id=webhook_id,
event_id=event_id,
event_type=event_type,
event=event,
instance_id=get_instance_id(),
instance_name=get_instance_name(),
)
.json(sort_keys=True, exclude_none=True)
.encode()
)
digest = None
if secret_token:
digest = hmac.new(secret_token.encode(), msg=data, digestmod=sha512).hexdigest()
return (data, digest)
@cached(ttl=30)
def get_webhooks_cached() -> List[Webhook]:
return Webhook.search()