mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 04:18:07 +00:00
250 lines
7.6 KiB
Python
250 lines
7.6 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import datetime
|
|
import hmac
|
|
import logging
|
|
from hashlib import sha512
|
|
from typing import List, Optional, Tuple
|
|
from uuid import UUID, uuid4
|
|
|
|
import requests
|
|
from memoization import cached
|
|
from onefuzztypes.enums import ErrorCode, WebhookMessageState
|
|
from onefuzztypes.events import Event, EventMessage, EventPing, EventType
|
|
from onefuzztypes.models import Error, Result
|
|
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
|
|
from onefuzztypes.webhooks import WebhookMessage
|
|
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
|
|
from pydantic import BaseModel, Field
|
|
|
|
from .__version__ import __version__
|
|
from .azure.creds import get_instance_id, get_instance_name
|
|
from .azure.queue import queue_object
|
|
from .azure.storage import StorageType
|
|
from .orm import MappingIntStrAny, ORMMixin
|
|
|
|
# Maximum delivery attempts for one webhook message. After a failed send,
# WebhookMessageLog.process re-queues while try_count < MAX_TRIES and marks
# the message as failed otherwise.
MAX_TRIES = 5
# Age threshold (in days) used by WebhookMessageLog.search_expired when
# selecting old log entries by their storage Timestamp.
EXPIRE_DAYS = 7
# Value of the User-Agent header sent with every webhook POST.
USER_AGENT = "onefuzz-webhook %s" % (__version__)
|
|
|
|
|
|
class WebhookMessageQueueObj(BaseModel):
    """Queue payload identifying a single WebhookMessageLog entry.

    Placed on the "webhooks" queue by WebhookMessageLog.queue_webhook and
    consumed by WebhookMessageLog.process_from_queue, which looks the log
    entry up again by these two ids.
    """

    # Id of the webhook the message is addressed to.
    webhook_id: UUID
    # Id of the event being delivered.
    event_id: UUID
|
|
|
|
|
|
class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
    """Stored delivery record for one event addressed to one webhook.

    Each record carries a delivery state and an attempt counter. Delivery is
    driven by WebhookMessageQueueObj payloads on the "webhooks" queue.
    """

    # Aliased to the storage "Timestamp" column; used by search_expired and
    # left out of exports by export_exclude.
    timestamp: Optional[datetime.datetime] = Field(alias="Timestamp")

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # A log entry is identified by the (webhook, event) pair.
        return "webhook_id", "event_id"

    def export_exclude(self) -> Optional[MappingIntStrAny]:
        # Storage bookkeeping fields are not part of the exported record.
        return {
            "etag": ...,
            "timestamp": ...,
        }

    @classmethod
    def search_expired(cls) -> List["WebhookMessageLog"]:
        """Return entries whose storage Timestamp is older than EXPIRE_DAYS days."""
        cutoff = datetime.datetime.utcnow() - datetime.timedelta(days=EXPIRE_DAYS)
        query = "Timestamp lt datetime'%s'" % cutoff.isoformat()
        return cls.search(raw_unchecked_filter=query)

    @classmethod
    def process_from_queue(cls, obj: WebhookMessageQueueObj) -> None:
        """Load the entry named by a queue payload and attempt delivery.

        A payload whose entry no longer exists is logged and dropped.
        """
        entry = cls.get(obj.webhook_id, obj.event_id)
        if entry is not None:
            entry.process()
            return

        logging.error(
            "webhook message missing. %s:%s", obj.webhook_id, obj.event_id
        )

    def process(self) -> None:
        """Make one delivery attempt and persist the resulting state.

        Entries already in a terminal state (failed / succeeded) are skipped.
        A failed attempt is re-queued as "retrying" until MAX_TRIES attempts
        have been made, after which the entry is marked "failed".
        """
        terminal_states = (WebhookMessageState.failed, WebhookMessageState.succeeded)
        if self.state in terminal_states:
            logging.info(
                "webhook message already handled: %s:%s", self.webhook_id, self.event_id
            )
            return

        # Count the attempt before sending so the counter includes this try.
        self.try_count += 1

        logging.debug("sending webhook: %s:%s", self.webhook_id, self.event_id)
        if self.send():
            self.state = WebhookMessageState.succeeded
            self.save()
            logging.info("sent webhook event: %s:%s", self.webhook_id, self.event_id)
            return

        if self.try_count >= MAX_TRIES:
            self.state = WebhookMessageState.failed
            self.save()
            logging.warning(
                "sending webhook event failed %d times. %s:%s",
                self.try_count,
                self.webhook_id,
                self.event_id,
            )
            return

        # Persist the retrying state before re-queuing; queue_webhook picks
        # its delay from self.state.
        self.state = WebhookMessageState.retrying
        self.save()
        self.queue_webhook()
        logging.warning(
            "sending webhook event failed, re-queued. %s:%s",
            self.webhook_id,
            self.event_id,
        )

    def send(self) -> bool:
        """Deliver this entry through its webhook; return True on success.

        Returns False when the webhook can no longer be found or when the
        delivery raises; both cases are logged.
        """
        target = Webhook.get_by_id(self.webhook_id)
        if isinstance(target, Error):
            logging.error(
                "webhook no longer exists: %s:%s", self.webhook_id, self.event_id
            )
            return False

        try:
            delivered = target.send(self)
        except Exception as err:
            logging.error(
                "webhook failed with exception: %s:%s - %s",
                self.webhook_id,
                self.event_id,
                err,
            )
            delivered = False
        return delivered

    def queue_webhook(self) -> None:
        """Place this entry on the "webhooks" queue.

        "queued" entries get visibility timeout 0 and "retrying" entries get
        30 (presumably seconds; confirm against queue_object). Any other
        state is logged and not queued.
        """
        payload = WebhookMessageQueueObj(webhook_id=self.webhook_id, event_id=self.event_id)

        delay_by_state = {
            WebhookMessageState.queued: 0,
            WebhookMessageState.retrying: 30,
        }
        delay = delay_by_state.get(self.state)
        if delay is None:
            logging.error(
                "invalid WebhookMessage queue state, not queuing. %s:%s - %s",
                self.webhook_id,
                self.event_id,
                self.state,
            )
            return

        queue_object(
            "webhooks",
            payload,
            StorageType.config,
            visibility_timeout=delay,
        )
|
|
|
|
|
|
class Webhook(BASE_WEBHOOK, ORMMixin):
    """Persisted webhook registration that receives OneFuzz events over HTTP.

    Matching events are recorded as WebhookMessageLog entries and queued for
    delivery; the actual HTTP POST happens in ``send``.
    """

    @classmethod
    def key_fields(cls) -> Tuple[str, Optional[str]]:
        # Key fields consumed by ORMMixin: (webhook_id, name).
        return ("webhook_id", "name")

    @classmethod
    def send_event(cls, event_message: EventMessage) -> None:
        """Record and queue ``event_message`` for every webhook subscribed to its type.

        The webhook list comes from ``get_webhooks_cached`` (30 second TTL), so
        very recent registrations or deletions may not be reflected yet.
        """
        for webhook in get_webhooks_cached():
            if event_message.event_type not in webhook.event_types:
                continue

            webhook._add_event(event_message)

    @classmethod
    def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
        """Look up a webhook by id.

        :param webhook_id: id of the webhook to find.
        :return: the webhook, or an ``Error`` (INVALID_REQUEST) when no webhook
            or more than one webhook matches.
        """
        webhooks = cls.search(query={"webhook_id": [webhook_id]})
        if not webhooks:
            return Error(
                code=ErrorCode.INVALID_REQUEST, errors=["unable to find webhook"]
            )

        if len(webhooks) != 1:
            # Previously reported "Notification" (copy-paste); this is a webhook lookup.
            return Error(
                code=ErrorCode.INVALID_REQUEST,
                errors=["error identifying Webhook"],
            )

        return webhooks[0]

    def _add_event(self, event_message: EventMessage) -> None:
        """Persist a delivery log entry for ``event_message`` and queue it."""
        message = WebhookMessageLog(
            webhook_id=self.webhook_id,
            event_id=event_message.event_id,
            event_type=event_message.event_type,
            event=event_message.event,
            instance_id=event_message.instance_id,
            instance_name=event_message.instance_name,
        )
        message.save()
        message.queue_webhook()

    def ping(self) -> EventPing:
        """Queue a ping event to this webhook and return the ping that was sent."""
        ping = EventPing(ping_id=uuid4())
        self._add_event(
            EventMessage(
                event_type=EventType.ping,
                event=ping,
                instance_id=get_instance_id(),
                instance_name=get_instance_name(),
            )
        )
        return ping

    def send(self, message_log: WebhookMessageLog, *, timeout: float = 30.0) -> bool:
        """POST the event in ``message_log`` to this webhook's URL.

        :param message_log: the logged event to deliver.
        :param timeout: seconds ``requests`` waits for the endpoint (applied to
            both connect and read) before raising. Without a timeout a stalled
            endpoint could block the delivering worker indefinitely.
        :return: ``response.ok`` (True for a status code below 400).
        :raises Exception: if the webhook has no URL configured; ``requests``
            exceptions (including timeouts) also propagate to the caller.
        """
        if self.url is None:
            raise Exception("webhook URL incorrectly removed: %s" % self.webhook_id)

        data, digest = build_message(
            webhook_id=self.webhook_id,
            event_id=message_log.event_id,
            event_type=message_log.event_type,
            event=message_log.event,
            secret_token=self.secret_token,
        )

        headers = {"Content-type": "application/json", "User-Agent": USER_AGENT}

        # The HMAC digest is only present when the webhook has a secret token.
        if digest:
            headers["X-Onefuzz-Digest"] = digest

        response = requests.post(
            self.url,
            data=data,
            headers=headers,
            timeout=timeout,
        )
        return response.ok
|
|
|
|
|
|
def build_message(
    *,
    webhook_id: UUID,
    event_id: UUID,
    event_type: EventType,
    event: Event,
    secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
    """Serialize a webhook payload and optionally sign it.

    The payload is a WebhookMessage rendered as JSON with sorted keys and
    ``None`` fields omitted, then encoded to bytes.

    :return: ``(payload, digest)``, where ``digest`` is the hex HMAC-SHA512
        of the payload keyed by ``secret_token``, or ``None`` when no
        (non-empty) secret token is given.
    """
    envelope = WebhookMessage(
        webhook_id=webhook_id,
        event_id=event_id,
        event_type=event_type,
        event=event,
        instance_id=get_instance_id(),
        instance_name=get_instance_name(),
    )
    payload = envelope.json(sort_keys=True, exclude_none=True).encode()

    signature = (
        hmac.new(secret_token.encode(), msg=payload, digestmod=sha512).hexdigest()
        if secret_token
        else None
    )
    return payload, signature
|
|
|
|
|
|
@cached(ttl=30)
def get_webhooks_cached() -> List[Webhook]:
    """Return all registered webhooks, memoized for 30 seconds."""
    registered: List[Webhook] = Webhook.search()
    return registered
|