Unify Dashboard & Webhook events (#394)

This change unifies the previously ad hoc SignalR dashboard events and webhooks into a single event format.
bmc-msft
2021-01-11 16:43:09 -05:00
committed by GitHub
parent 465727680d
commit 513d1f52c9
37 changed files with 2970 additions and 825 deletions
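
The recurring pattern in the diffs below: call sites stop hand-rolling a state assignment plus save() plus separate dashboard/webhook notifications, and instead go through set_state() or a typed send_event(). What follows is a minimal, self-contained sketch of the unified dispatch; the classes here are stand-ins, and the real Event, EventMessage, and get_event_type live in onefuzztypes.events.

from dataclasses import dataclass
from queue import Queue
from typing import Any
from uuid import UUID, uuid4

@dataclass
class EventTaskStopped:      # stand-in; the real class lives in onefuzztypes.events
    job_id: UUID
    task_id: UUID

@dataclass
class EventMessage:          # stand-in for onefuzztypes.events.EventMessage
    event_type: str
    event: Any

EVENTS: Queue = Queue()      # drained by get_events() for the SignalR "events" target

def send_event(event: Any) -> None:
    # the real code derives event_type via get_event_type(event);
    # a class-name lookup stands in for it here
    message = EventMessage(event_type=type(event).__name__, event=event)
    EVENTS.put(message)              # dashboard path
    # Webhook.send_event(message)    # webhook path receives the same message

send_event(EventTaskStopped(job_id=uuid4(), task_id=uuid4()))

Both consumers receive the same EventMessage: the dashboard drains EVENTS via get_events(), and Webhook.send_event fans the message out to registered webhooks.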

View File

@ -67,13 +67,11 @@ def on_state_update(
# they send 'init' with reimage_requested, it's because the node was reimaged
# successfully.
node.reimage_requested = False
node.state = state
node.save()
node.set_state(state)
return None
logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
node.state = state
node.save()
node.set_state(state)
if state == NodeState.free:
logging.info("node now available for work: %s", machine_id)
@ -113,9 +111,7 @@ def on_state_update(
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state not in [TaskState.running, TaskState.setting_up]:
task.state = TaskState.setting_up
task.save()
task.on_start()
task.set_state(TaskState.setting_up)
# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
@ -160,8 +156,7 @@ def on_worker_event_running(
return node
if node.state not in NodeState.ready_for_reset():
node.state = NodeState.busy
node.save()
node.set_state(NodeState.busy)
node_task = NodeTasks(
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
@ -184,12 +179,7 @@ def on_worker_event_running(
task.job_id,
task.task_id,
)
task.state = TaskState.running
task.save()
# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()
task.set_state(TaskState.running)
task_event = TaskEvent(
task_id=task.task_id,

View File

@ -66,7 +66,7 @@ def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
if not autoscale_config.region:
raise Exception("Region is missing")
scaleset = Scaleset.create(
Scaleset.create(
pool_name=pool.name,
vm_sku=autoscale_config.vm_sku,
image=autoscale_config.image,
@ -75,7 +75,6 @@ def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
spot_instances=autoscale_config.spot_instances,
tags={"pool": pool.name},
)
scaleset.save()
nodes_needed -= max_nodes_scaleset

View File

@ -1,54 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
from enum import Enum
from queue import Empty, Queue
from typing import Dict, Optional, Union
from uuid import UUID
from onefuzztypes.primitives import Event
EVENTS: Queue = Queue()
def resolve(data: Event) -> Union[str, int, Dict[str, str]]:
if isinstance(data, str):
return data
if isinstance(data, UUID):
return str(data)
elif isinstance(data, Enum):
return data.name
elif isinstance(data, int):
return data
elif isinstance(data, dict):
for x in data:
data[x] = str(data[x])
return data
raise NotImplementedError("no conversion from %s" % type(data))
def get_event() -> Optional[str]:
events = []
for _ in range(10):
try:
(event, data) = EVENTS.get(block=False)
events.append({"type": event, "data": data})
EVENTS.task_done()
except Empty:
break
if events:
return json.dumps({"target": "dashboard", "arguments": events})
else:
return None
def add_event(message_type: str, data: Dict[str, Event]) -> None:
for key in data:
data[key] = resolve(data[key])
EVENTS.put((message_type, data))

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import logging
from queue import Empty, Queue
from typing import Optional
from onefuzztypes.events import Event, EventMessage, get_event_type
from .webhooks import Webhook
EVENTS: Queue = Queue()
def get_events() -> Optional[str]:
events = []
for _ in range(5):
try:
event = EVENTS.get(block=False)
events.append(json.loads(event.json(exclude_none=True)))
EVENTS.task_done()
except Empty:
break
if events:
return json.dumps({"target": "events", "arguments": events})
else:
return None
def send_event(event: Event) -> None:
event_type = get_event_type(event)
logging.info("sending event: %s - %s", event_type, event)
event_message = EventMessage(event_type=event_type, event=event)
EVENTS.put(event_message)
Webhook.send_event(event_message)
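
Illustrative usage of the module above (not self-contained: EventProxyCreated comes from the proxy diff further down, and the serialized field names are an assumption):

send_event(EventProxyCreated(region="eastus"))

payload = get_events()
# payload is roughly:
# {"target": "events",
#  "arguments": [{"event_id": "...", "event_type": "proxy_created",
#                 "event": {"region": "eastus"}}]}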

View File

@ -8,8 +8,10 @@ from datetime import datetime, timedelta
from typing import List, Optional, Tuple
from onefuzztypes.enums import JobState, TaskState
from onefuzztypes.events import EventJobCreated, EventJobStopped
from onefuzztypes.models import Job as BASE_JOB
from .events import send_event
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .tasks.main import Task
@ -37,13 +39,6 @@ class Job(BASE_JOB, ORMMixin):
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"task_info": ...}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"state": ...,
"error": ...,
}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"machine_id": ...,
@ -70,6 +65,11 @@ class Job(BASE_JOB, ORMMixin):
task.mark_stopping()
else:
self.state = JobState.stopped
send_event(
EventJobStopped(
job_id=self.job_id, config=self.config, user_info=self.user_info
)
)
self.save()
def on_start(self) -> None:
@ -77,3 +77,14 @@ class Job(BASE_JOB, ORMMixin):
if self.end_time is None:
self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
self.save()
def save(self, new: bool = False, require_etag: bool = False) -> None:
created = self.etag is None
super().save(new=new, require_etag=require_etag)
if created:
send_event(
EventJobCreated(
job_id=self.job_id, config=self.config, user_info=self.user_info
)
)
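
The created flag above works because the ORM leaves etag unset until the entity has been written at least once; a stand-alone illustration of the idiom (stand-in class, not the real ORMMixin):

class FakeOrm:
    """Stand-in for the first-save detection idiom, not the real ORMMixin."""

    def __init__(self) -> None:
        self.etag = None                 # unset until the row has been written

    def save(self) -> None:
        created = self.etag is None
        self.etag = 'W/"1"'              # presumably returned by the table service on write
        if created:
            print("first save: emit the created event")
        else:
            print("subsequent save: no created event")

job = FakeOrm()
job.save()   # emits the created event
job.save()   # does not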

View File

@ -10,6 +10,7 @@ from uuid import UUID
from memoization import cached
from onefuzztypes import models
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import EventCrashReported, EventFileAdded
from onefuzztypes.models import (
ADOTemplate,
Error,
@ -27,7 +28,7 @@ from ..azure.containers import (
)
from ..azure.queue import send_message
from ..azure.storage import StorageType
from ..dashboard import add_event
from ..events import send_event
from ..orm import ORMMixin
from ..reports import get_report
from ..tasks.config import get_input_container_queues
@ -116,16 +117,16 @@ def new_files(container: Container, filename: str) -> None:
if metadata:
results["metadata"] = metadata
report = get_report(container, filename)
if report:
results["executable"] = report.executable
results["crash_type"] = report.crash_type
results["crash_site"] = report.crash_site
results["job_id"] = report.job_id
results["task_id"] = report.task_id
notifications = get_notifications(container)
if notifications:
report = get_report(container, filename)
if report:
results["executable"] = report.executable
results["crash_type"] = report.crash_type
results["crash_site"] = report.crash_site
results["job_id"] = report.job_id
results["task_id"] = report.task_id
logging.info("notifications for %s %s %s", container, filename, notifications)
done = []
for notification in notifications:
@ -154,4 +155,9 @@ def new_files(container: Container, filename: str) -> None:
)
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
add_event("new_file", results)
if report:
send_event(
EventCrashReported(report=report, container=container, filename=filename)
)
else:
send_event(EventFileAdded(container=container, filename=filename))

View File

@ -39,7 +39,6 @@ from pydantic import BaseModel, Field
from typing_extensions import Protocol
from .azure.table import get_client
from .dashboard import add_event
from .telemetry import track_event_filtered
from .updates import queue_update
@ -255,21 +254,9 @@ class ORMMixin(ModelMixin):
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {}
def event_include(self) -> Optional[MappingIntStrAny]:
return {}
def telemetry(self) -> Any:
return self.raw(exclude_none=True, include=self.telemetry_include())
def _event_as_needed(self) -> None:
# Upon ORM save, if the object returns event data, we'll send it to the
# dashboard event subsystem
data = self.raw(exclude_none=True, include=self.event_include())
if not data:
return
add_event(self.table_name(), data)
def get_keys(self) -> Tuple[KEY, KEY]:
partition_key_field, row_key_field = self.key_fields()
@ -331,13 +318,9 @@ class ORMMixin(ModelMixin):
if telem:
track_event_filtered(TelemetryEvent[self.table_name()], telem)
self._event_as_needed()
return None
def delete(self) -> None:
# fire off an event so Signalr knows it's being deleted
self._event_as_needed()
partition_key, row_key = self.get_keys()
client = get_client()

View File

@ -16,6 +16,16 @@ from onefuzztypes.enums import (
PoolState,
ScalesetState,
)
from onefuzztypes.events import (
EventNodeCreated,
EventNodeDeleted,
EventNodeStateUpdated,
EventPoolCreated,
EventPoolDeleted,
EventScalesetCreated,
EventScalesetDeleted,
EventScalesetFailed,
)
from onefuzztypes.models import AutoScaleConfig, Error
from onefuzztypes.models import Node as BASE_NODE
from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey
@ -60,6 +70,7 @@ from .azure.vmss import (
resize_vmss,
update_extensions,
)
from .events import send_event
from .extension import fuzz_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
@ -76,6 +87,31 @@ class Node(BASE_NODE, ORMMixin):
# should only be unset during agent_registration POST
reimage_queued: bool = Field(default=False)
@classmethod
def create(
cls,
*,
pool_name: PoolName,
machine_id: UUID,
scaleset_id: Optional[UUID],
version: str,
) -> "Node":
node = cls(
pool_name=pool_name,
machine_id=machine_id,
scaleset_id=scaleset_id,
version=version,
)
node.save()
send_event(
EventNodeCreated(
machine_id=node.machine_id,
scaleset_id=node.scaleset_id,
pool_name=node.pool_name,
)
)
return node
@classmethod
def search_states(
cls,
@ -163,14 +199,6 @@ class Node(BASE_NODE, ORMMixin):
"scaleset_id": ...,
}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_name": ...,
"machine_id": ...,
"state": ...,
"scaleset_id": ...,
}
def scaleset_node_exists(self) -> bool:
if self.scaleset_id is None:
return False
@ -302,6 +330,7 @@ class Node(BASE_NODE, ORMMixin):
""" Tell the node to stop everything. """
self.set_shutdown()
self.stop()
self.set_state(NodeState.halt)
@classmethod
def get_dead_nodes(
@ -315,8 +344,29 @@ class Node(BASE_NODE, ORMMixin):
raw_unchecked_filter=time_filter,
)
def set_state(self, state: NodeState) -> None:
if self.state != state:
self.state = state
send_event(
EventNodeStateUpdated(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
state=state,
)
)
self.save()
def delete(self) -> None:
NodeTasks.clear_by_machine_id(self.machine_id)
send_event(
EventNodeDeleted(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
)
)
super().delete()
NodeMessage.clear_messages(self.machine_id)
@ -410,7 +460,7 @@ class Pool(BASE_POOL, ORMMixin):
client_id: Optional[UUID],
autoscale: Optional[AutoScaleConfig],
) -> "Pool":
return cls(
entry = cls(
name=name,
os=os,
arch=arch,
@ -419,6 +469,17 @@ class Pool(BASE_POOL, ORMMixin):
config=None,
autoscale=autoscale,
)
entry.save()
send_event(
EventPoolCreated(
pool_name=name,
os=os,
arch=arch,
managed=managed,
autoscale=autoscale,
)
)
return entry
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {
@ -435,15 +496,6 @@ class Pool(BASE_POOL, ORMMixin):
"timestamp": ...,
}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"name": ...,
"pool_id": ...,
"os": ...,
"state": ...,
"managed": ...,
}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_id": ...,
@ -533,6 +585,17 @@ class Pool(BASE_POOL, ORMMixin):
query["state"] = states
return cls.search(query=query)
def set_shutdown(self, now: bool) -> None:
if self.state in [PoolState.halt, PoolState.shutdown]:
return
if now:
self.state = PoolState.halt
else:
self.state = PoolState.shutdown
self.save()
def shutdown(self) -> None:
""" shutdown allows nodes to finish current work then delete """
scalesets = Scaleset.search_by_pool(self.name)
@ -545,8 +608,7 @@ class Pool(BASE_POOL, ORMMixin):
return
for scaleset in scalesets:
scaleset.state = ScalesetState.shutdown
scaleset.save()
scaleset.set_shutdown(now=False)
for node in nodes:
node.set_shutdown()
@ -555,6 +617,7 @@ class Pool(BASE_POOL, ORMMixin):
def halt(self) -> None:
""" halt the pool immediately """
scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes:
@ -577,21 +640,15 @@ class Pool(BASE_POOL, ORMMixin):
def key_fields(cls) -> Tuple[str, str]:
return ("name", "pool_id")
def delete(self) -> None:
super().delete()
send_event(EventPoolDeleted(pool_name=self.name))
class Scaleset(BASE_SCALESET, ORMMixin):
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"nodes": ...}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_name": ...,
"scaleset_id": ...,
"state": ...,
"os": ...,
"size": ...,
"error": ...,
}
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"scaleset_id": ...,
@ -615,7 +672,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
client_id: Optional[UUID] = None,
client_object_id: Optional[UUID] = None,
) -> "Scaleset":
return cls(
entry = cls(
pool_name=pool_name,
vm_sku=vm_sku,
image=image,
@ -627,6 +684,18 @@ class Scaleset(BASE_SCALESET, ORMMixin):
client_object_id=client_object_id,
tags=tags,
)
entry.save()
send_event(
EventScalesetCreated(
scaleset_id=entry.scaleset_id,
pool_name=entry.pool_name,
vm_sku=vm_sku,
image=image,
region=region,
size=size,
)
)
return entry
@classmethod
def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]:
@ -651,6 +720,20 @@ class Scaleset(BASE_SCALESET, ORMMixin):
def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]:
return cls.search(query={"client_object_id": [object_id]})
def set_failed(self, error: Error) -> None:
if self.error is not None:
return
self.error = error
self.state = ScalesetState.creation_failed
self.save()
send_event(
EventScalesetFailed(
scaleset_id=self.scaleset_id, pool_name=self.pool_name, error=self.error
)
)
def init(self) -> None:
logging.info("scaleset init: %s", self.scaleset_id)
@ -660,9 +743,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
# scaleset being added to the pool.
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.state = ScalesetState.halt
self.save()
self.set_failed(pool)
return
if pool.state == PoolState.init:
@ -672,14 +753,16 @@ class Scaleset(BASE_SCALESET, ORMMixin):
elif pool.state == PoolState.running:
image_os = get_os(self.region, self.image)
if isinstance(image_os, Error):
self.error = image_os
self.state = ScalesetState.creation_failed
self.set_failed(image_os)
return
elif image_os != pool.os:
self.error = Error(
error = Error(
code=ErrorCode.INVALID_REQUEST,
errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
)
self.state = ScalesetState.creation_failed
self.set_failed(error)
return
else:
self.state = ScalesetState.setup
else:
@ -698,26 +781,23 @@ class Scaleset(BASE_SCALESET, ORMMixin):
logging.info("creating network: %s", self.region)
result = network.create()
if isinstance(result, Error):
self.error = result
self.state = ScalesetState.creation_failed
self.set_failed(result)
return
self.save()
return
if self.auth is None:
self.error = Error(
error = Error(
code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
)
self.state = ScalesetState.creation_failed
self.save()
self.set_failed(error)
return
vmss = get_vmss(self.scaleset_id)
if vmss is None:
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.state = ScalesetState.halt
self.save()
self.set_failed(pool)
return
logging.info("creating scaleset: %s", self.scaleset_id)
@ -736,13 +816,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.tags,
)
if isinstance(result, Error):
self.error = result
logging.error(
"stopping task because of failed vmss: %s %s",
self.scaleset_id,
result,
)
self.state = ScalesetState.creation_failed
self.set_failed(result)
return
else:
logging.info("creating scaleset: %s", self.scaleset_id)
elif vmss.provisioning_state == "Creating":
@ -750,10 +825,10 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.try_set_identity(vmss)
else:
logging.info("scaleset running: %s", self.scaleset_id)
error = self.try_set_identity(vmss)
if error:
self.state = ScalesetState.creation_failed
self.error = error
identity_result = self.try_set_identity(vmss)
if identity_result:
self.set_failed(identity_result)
return
else:
self.state = ScalesetState.running
self.save()
@ -843,8 +918,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.delete_nodes(to_delete)
for node in to_delete:
node.set_halt()
node.state = NodeState.halt
node.save()
if to_reimage:
self.reimage_nodes(to_reimage)
@ -967,6 +1040,17 @@ class Scaleset(BASE_SCALESET, ORMMixin):
node.reimage_queued = True
node.save()
def set_shutdown(self, now: bool) -> None:
if self.state in [ScalesetState.halt, ScalesetState.shutdown]:
return
if now:
self.state = ScalesetState.halt
else:
self.state = ScalesetState.shutdown
self.save()
def shutdown(self) -> None:
size = get_vmss_size(self.scaleset_id)
logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
@ -977,7 +1061,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.halt()
def halt(self) -> None:
self.state = ScalesetState.halt
ScalesetShrinkQueue(self.scaleset_id).delete()
for node in Node.search_states(scaleset_id=self.scaleset_id):
@ -1050,8 +1133,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.halt()
self.set_failed(pool)
return
logging.debug("updating scaleset configs: %s", self.scaleset_id)
@ -1068,6 +1150,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):
def key_fields(cls) -> Tuple[str, str]:
return ("pool_name", "scaleset_id")
def delete(self) -> None:
super().delete()
send_event(
EventScalesetDeleted(scaleset_id=self.scaleset_id, pool_name=self.pool_name)
)
class ShrinkEntry(BaseModel):
shrink_id: UUID = Field(default_factory=uuid4)
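
Node.set_state above only emits EventNodeStateUpdated when the state actually changes; a stand-alone sketch of that guard (stand-in enum and class, not the real model):

from enum import Enum
from typing import List, Tuple

class NodeState(Enum):       # stand-in; the real enum is onefuzztypes.enums.NodeState
    init = 1
    busy = 2

class FakeNode:
    """Stand-in showing the change guard, not the real Node model."""

    def __init__(self) -> None:
        self.state = NodeState.init
        self.sent: List[Tuple[str, str]] = []

    def set_state(self, state: NodeState) -> None:
        if self.state != state:                  # emit only on an actual change
            self.state = state
            self.sent.append(("node_state_updated", state.name))
        # the real method persists via self.save() afterwards

node = FakeNode()
node.set_state(NodeState.busy)
node.set_state(NodeState.busy)                   # no second event
assert len(node.sent) == 1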

View File

@ -8,7 +8,8 @@ import logging
from typing import List, Optional, Tuple
from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import VmState
from onefuzztypes.enums import ErrorCode, VmState
from onefuzztypes.events import EventProxyCreated, EventProxyDeleted, EventProxyFailed
from onefuzztypes.models import (
Authentication,
Error,
@ -26,8 +27,9 @@ from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .azure.vm import VM
from .events import send_event
from .extension import proxy_manager_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .orm import ORMMixin, QueryFilter
from .proxy_forward import ProxyForward
PROXY_SKU = "Standard_B2s"
@ -41,7 +43,7 @@ class Proxy(ORMMixin):
state: VmState = Field(default=VmState.init)
auth: Authentication = Field(default_factory=build_auth)
ip: Optional[str]
error: Optional[str]
error: Optional[Error]
version: str = Field(default=__version__)
heartbeat: Optional[ProxyHeartbeat]
@ -49,14 +51,6 @@ class Proxy(ORMMixin):
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("region", None)
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"region": ...,
"state": ...,
"ip": ...,
"error": ...,
}
def get_vm(self) -> VM:
vm = VM(
name="proxy-%s" % self.region,
@ -72,42 +66,59 @@ class Proxy(ORMMixin):
vm_data = vm.get()
if vm_data:
if vm_data.provisioning_state == "Failed":
self.set_failed(vm)
self.set_provision_failed(vm_data)
return
else:
self.save_proxy_config()
self.state = VmState.extensions_launch
else:
result = vm.create()
if isinstance(result, Error):
self.error = repr(result)
self.state = VmState.stopping
self.set_failed(result)
return
self.save()
def set_failed(self, vm_data: VirtualMachine) -> None:
logging.error("vm failed to provision: %s", vm_data.name)
def set_provision_failed(self, vm_data: VirtualMachine) -> None:
errors = ["provisioning failed"]
for status in vm_data.instance_view.statuses:
if status.level.name.lower() == "error":
logging.error(
"vm status: %s %s %s %s",
vm_data.name,
status.code,
status.display_status,
status.message,
errors.append(
f"code:{status.code} status:{status.display_status} "
f"message:{status.message}"
)
self.state = VmState.vm_allocation_failed
self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=errors,
)
)
return
def set_failed(self, error: Error) -> None:
if self.error is not None:
return
logging.error("proxy vm failed: %s - %s", self.region, error)
send_event(EventProxyFailed(region=self.region, error=error))
self.error = error
self.state = VmState.stopping
self.save()
def extensions_launch(self) -> None:
vm = self.get_vm()
vm_data = vm.get()
if not vm_data:
logging.error("Azure VM does not exist: %s", vm.name)
self.state = VmState.stopping
self.save()
self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=["azure not able to find vm"],
)
)
return
if vm_data.provisioning_state == "Failed":
self.set_failed(vm_data)
self.save()
self.set_provision_failed(vm_data)
return
ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
@ -119,9 +130,8 @@ class Proxy(ORMMixin):
extensions = proxy_manager_extensions(self.region)
result = vm.add_extensions(extensions)
if isinstance(result, Error):
logging.error("vm extension failed: %s", repr(result))
self.error = repr(result)
self.state = VmState.stopping
self.set_failed(result)
return
elif result:
self.state = VmState.running
@ -231,4 +241,9 @@ class Proxy(ORMMixin):
proxy = Proxy(region=region)
proxy.save()
send_event(EventProxyCreated(region=region))
return proxy
def delete(self) -> None:
super().delete()
send_event(EventProxyDeleted(region=self.region))
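
Proxy.set_failed above is idempotent: once error is set, later failures neither overwrite it nor re-emit EventProxyFailed. A stand-alone sketch of the pattern (stand-in class, not the real Proxy model):

from typing import List, Optional

class FakeProxy:
    """Stand-in for the report-failure-once pattern."""

    def __init__(self, region: str) -> None:
        self.region = region
        self.error: Optional[str] = None
        self.emitted: List[str] = []

    def set_failed(self, error: str) -> None:
        if self.error is not None:       # already failed; keep the first error
            return
        self.error = error
        self.emitted.append("proxy_failed")
        # the real method also moves state to VmState.stopping and saves

proxy = FakeProxy("eastus")
proxy.set_failed("provisioning failed")
proxy.set_failed("later failure")        # ignored
assert proxy.emitted == ["proxy_failed"]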

View File

@ -7,6 +7,7 @@ import json
import logging
from typing import Optional, Union
from memoization import cached
from onefuzztypes.models import Report
from onefuzztypes.primitives import Container
from pydantic import ValidationError
@ -16,7 +17,7 @@ from .azure.storage import StorageType
def parse_report(
content: Union[str, bytes], metadata: Optional[str] = None
content: Union[str, bytes], file_path: Optional[str] = None
) -> Optional[Report]:
if isinstance(content, bytes):
try:
@ -24,7 +25,7 @@ def parse_report(
except UnicodeDecodeError as err:
logging.error(
"unable to parse report (%s): unicode decode of report failed - %s",
metadata,
file_path,
err,
)
return None
@ -33,28 +34,30 @@ def parse_report(
data = json.loads(content)
except json.decoder.JSONDecodeError as err:
logging.error(
"unable to parse report (%s): json decoding failed - %s", metadata, err
"unable to parse report (%s): json decoding failed - %s", file_path, err
)
return None
try:
entry = Report.parse_obj(data)
except ValidationError as err:
logging.error("unable to parse report (%s): %s", metadata, err)
logging.error("unable to parse report (%s): %s", file_path, err)
return None
return entry
# cache the last 1000 reports
@cached(max_size=1000)
def get_report(container: Container, filename: str) -> Optional[Report]:
metadata = "/".join([container, filename])
file_path = "/".join([container, filename])
if not filename.endswith(".json"):
logging.error("get_report invalid extension: %s", metadata)
logging.error("get_report invalid extension: %s", file_path)
return None
blob = get_blob(container, filename, StorageType.corpus)
if blob is None:
logging.error("get_report invalid blob: %s", metadata)
logging.error("get_report invalid blob: %s", file_path)
return None
return parse_report(blob, metadata=metadata)
return parse_report(blob, file_path=file_path)

View File

@ -9,22 +9,23 @@ from typing import List, Optional, Tuple, Union
from uuid import UUID
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import (
EventTaskCreated,
EventTaskFailed,
EventTaskStateUpdated,
EventTaskStopped,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
)
from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward
from ..webhooks import Webhook
class Task(BASE_TASK, ORMMixin):
@ -58,8 +59,8 @@ class Task(BASE_TASK, ORMMixin):
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
send_event(
EventTaskCreated(
job_id=task.job_id,
task_id=task.task_id,
config=config,
@ -116,18 +117,9 @@ class Task(BASE_TASK, ORMMixin):
"config": {"vm": {"count": ...}, "task": {"type": ...}},
}
def event_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"task_id": ...,
"state": ...,
"error": ...,
}
def init(self) -> None:
create_queue(self.task_id, StorageType.corpus)
self.state = TaskState.waiting
self.save()
self.set_state(TaskState.waiting)
def stopping(self) -> None:
# TODO: we need to 'unschedule' this task from the existing pools
@ -136,8 +128,7 @@ class Task(BASE_TASK, ORMMixin):
ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), StorageType.corpus)
Node.stop_task(self.task_id)
self.state = TaskState.stopped
self.save()
self.set_state(TaskState.stopped, send=False)
@classmethod
def search_states(
@ -195,10 +186,9 @@ class Task(BASE_TASK, ORMMixin):
)
return
self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(
self.set_state(TaskState.stopping, send=False)
send_event(
EventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
)
@ -211,11 +201,10 @@ class Task(BASE_TASK, ORMMixin):
return
self.error = error
self.state = TaskState.stopping
self.save()
self.set_state(TaskState.stopping, send=False)
Webhook.send_event(
WebhookEventTaskFailed(
send_event(
EventTaskFailed(
job_id=self.job_id,
task_id=self.task_id,
error=error,
@ -287,7 +276,6 @@ class Task(BASE_TASK, ORMMixin):
self.end_time = datetime.utcnow() + timedelta(
hours=self.config.task.duration
)
self.save()
from ..jobs import Job
@ -298,3 +286,22 @@ class Task(BASE_TASK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("job_id", "task_id")
def set_state(self, state: TaskState, send: bool = True) -> None:
if self.state == state:
return
self.state = state
if self.state in [TaskState.running, TaskState.setting_up]:
self.on_start()
self.save()
send_event(
EventTaskStateUpdated(
job_id=self.job_id,
task_id=self.task_id,
state=self.state,
end_time=self.end_time,
)
)

View File

@ -235,8 +235,7 @@ def schedule_tasks() -> None:
if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
for work_unit in work_set.work_units:
task = tasks_by_id[work_unit.task_id]
task.state = TaskState.scheduled
task.save()
task.set_state(TaskState.scheduled)
seen.add(task.task_id)
not_ready_count = len(tasks) - len(seen)

View File

@ -8,21 +8,15 @@ import hmac
import logging
from hashlib import sha512
from typing import List, Optional, Tuple
from uuid import UUID
from uuid import UUID, uuid4
import requests
from memoization import cached
from onefuzztypes.enums import ErrorCode, WebhookEventType, WebhookMessageState
from onefuzztypes.enums import ErrorCode, WebhookMessageState
from onefuzztypes.events import Event, EventMessage, EventPing, EventType
from onefuzztypes.models import Error, Result
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
from onefuzztypes.webhooks import (
WebhookEvent,
WebhookEventPing,
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
WebhookMessage,
)
from onefuzztypes.webhooks import WebhookMessage
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel
@ -140,34 +134,18 @@ class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
)
def get_event_type(event: WebhookEvent) -> WebhookEventType:
events = {
WebhookEventTaskCreated: WebhookEventType.task_created,
WebhookEventTaskFailed: WebhookEventType.task_failed,
WebhookEventTaskStopped: WebhookEventType.task_stopped,
WebhookEventPing: WebhookEventType.ping,
}
for event_class in events:
if isinstance(event, event_class):
return events[event_class]
raise NotImplementedError("unsupported event type: %s" % event)
class Webhook(BASE_WEBHOOK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("webhook_id", "name")
@classmethod
def send_event(cls, event: WebhookEvent) -> None:
event_type = get_event_type(event)
def send_event(cls, event_message: EventMessage) -> None:
for webhook in get_webhooks_cached():
if event_type not in webhook.event_types:
if event_message.event_type not in webhook.event_types:
continue
webhook._add_event(event_type, event)
webhook._add_event(event_message)
@classmethod
def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
@ -185,18 +163,19 @@ class Webhook(BASE_WEBHOOK, ORMMixin):
webhook = webhooks[0]
return webhook
def _add_event(self, event_type: WebhookEventType, event: WebhookEvent) -> None:
def _add_event(self, event_message: EventMessage) -> None:
message = WebhookMessageLog(
webhook_id=self.webhook_id,
event_type=event_type,
event=event,
event_id=event_message.event_id,
event_type=event_message.event_type,
event=event_message.event,
)
message.save()
message.queue_webhook()
def ping(self) -> WebhookEventPing:
ping = WebhookEventPing()
self._add_event(WebhookEventType.ping, ping)
def ping(self) -> EventPing:
ping = EventPing(ping_id=uuid4())
self._add_event(EventMessage(event_type=EventType.ping, event=ping))
return ping
def send(self, message_log: WebhookMessageLog) -> bool:
@ -228,8 +207,8 @@ def build_message(
*,
webhook_id: UUID,
event_id: UUID,
event_type: WebhookEventType,
event: WebhookEvent,
event_type: EventType,
event: Event,
secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
data = (