Mirror of https://github.com/microsoft/onefuzz.git, synced 2025-06-16 03:48:09 +00:00
Unify Dashboard & Webhook events (#394)
This change unifies the previously ad hoc SignalR events and webhooks into a single event format.
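For orientation before the diff: the pattern below is a minimal, self-contained sketch of the unified flow this change introduces, using only the standard library. The names send_event, get_events, and EventMessage mirror the new events.py shown further down; the dict-based event payload and the forward_to_webhooks stand-in are simplifications of the onefuzztypes pydantic models and Webhook.send_event, not the actual implementation.

import json
from dataclasses import dataclass, field
from queue import Empty, Queue
from typing import Optional
from uuid import UUID, uuid4


@dataclass
class EventMessage:
    # simplified stand-in for onefuzztypes.events.EventMessage
    event_type: str
    event: dict
    event_id: UUID = field(default_factory=uuid4)


EVENTS: Queue = Queue()


def forward_to_webhooks(message: EventMessage) -> None:
    # stand-in for Webhook.send_event(message): the real code filters the
    # webhooks subscribed to message.event_type and queues a WebhookMessageLog
    pass


def send_event(event_type: str, event: dict) -> None:
    # single entry point: the same message feeds the SignalR queue and the webhooks
    message = EventMessage(event_type=event_type, event=event)
    EVENTS.put(message)
    forward_to_webhooks(message)


def get_events() -> Optional[str]:
    # drain a few queued messages into one SignalR payload, as in events.py below
    events = []
    for _ in range(5):
        try:
            message = EVENTS.get(block=False)
            events.append({"event_type": message.event_type, "event": message.event})
            EVENTS.task_done()
        except Empty:
            break
    if events:
        return json.dumps({"target": "events", "arguments": events})
    return None


if __name__ == "__main__":
    send_event("task_stopped", {"task_id": str(uuid4())})
    print(get_events())

In the real module, send_event derives the event type with get_event_type, queues an EventMessage for the SignalR dashboard, and hands the same message to Webhook.send_event, so both consumers see an identical payload.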
@@ -67,13 +67,11 @@ def on_state_update(
# they send 'init' with reimage_requested, it's because the node was reimaged
# successfully.
node.reimage_requested = False
node.state = state
node.save()
node.set_state(state)
return None

logging.info("node state update: %s from:%s to:%s", machine_id, node.state, state)
node.state = state
node.save()
node.set_state(state)

if state == NodeState.free:
logging.info("node now available for work: %s", machine_id)
@@ -113,9 +111,7 @@ def on_state_update(
# Other states we would want to preserve are excluded by the
# outermost conditional check.
if task.state not in [TaskState.running, TaskState.setting_up]:
task.state = TaskState.setting_up
task.save()
task.on_start()
task.set_state(TaskState.setting_up)

# Note: we set the node task state to `setting_up`, even though
# the task itself may be `running`.
@@ -160,8 +156,7 @@ def on_worker_event_running(
return node

if node.state not in NodeState.ready_for_reset():
node.state = NodeState.busy
node.save()
node.set_state(NodeState.busy)

node_task = NodeTasks(
machine_id=machine_id, task_id=event.task_id, state=NodeTaskState.running
@@ -184,12 +179,7 @@ def on_worker_event_running(
task.job_id,
task.task_id,
)
task.state = TaskState.running
task.save()

# Start the clock for the task if it wasn't started already
# (as happens in 1.0.0 agents)
task.on_start()
task.set_state(TaskState.running)

task_event = TaskEvent(
task_id=task.task_id,
@@ -66,7 +66,7 @@ def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
if not autoscale_config.region:
raise Exception("Region is missing")

scaleset = Scaleset.create(
Scaleset.create(
pool_name=pool.name,
vm_sku=autoscale_config.vm_sku,
image=autoscale_config.image,
@@ -75,7 +75,6 @@ def scale_up(pool: Pool, scalesets: List[Scaleset], nodes_needed: int) -> None:
spot_instances=autoscale_config.spot_instances,
tags={"pool": pool.name},
)
scaleset.save()
nodes_needed -= max_nodes_scaleset
@@ -1,54 +0,0 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
from enum import Enum
from queue import Empty, Queue
from typing import Dict, Optional, Union
from uuid import UUID

from onefuzztypes.primitives import Event

EVENTS: Queue = Queue()


def resolve(data: Event) -> Union[str, int, Dict[str, str]]:
if isinstance(data, str):
return data
if isinstance(data, UUID):
return str(data)
elif isinstance(data, Enum):
return data.name
elif isinstance(data, int):
return data
elif isinstance(data, dict):
for x in data:
data[x] = str(data[x])
return data
raise NotImplementedError("no conversion from %s" % type(data))


def get_event() -> Optional[str]:
events = []

for _ in range(10):
try:
(event, data) = EVENTS.get(block=False)
events.append({"type": event, "data": data})
EVENTS.task_done()
except Empty:
break

if events:
return json.dumps({"target": "dashboard", "arguments": events})
else:
return None


def add_event(message_type: str, data: Dict[str, Event]) -> None:
for key in data:
data[key] = resolve(data[key])

EVENTS.put((message_type, data))
src/api-service/__app__/onefuzzlib/events.py (new file, 40 lines)
@@ -0,0 +1,40 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import json
import logging
from queue import Empty, Queue
from typing import Optional

from onefuzztypes.events import Event, EventMessage, get_event_type

from .webhooks import Webhook

EVENTS: Queue = Queue()


def get_events() -> Optional[str]:
events = []

for _ in range(5):
try:
event = EVENTS.get(block=False)
events.append(json.loads(event.json(exclude_none=True)))
EVENTS.task_done()
except Empty:
break

if events:
return json.dumps({"target": "events", "arguments": events})
else:
return None


def send_event(event: Event) -> None:
event_type = get_event_type(event)
logging.info("sending event: %s - %s", event_type, event)
event_message = EventMessage(event_type=event_type, event=event)
EVENTS.put(event_message)
Webhook.send_event(event_message)
@@ -8,8 +8,10 @@ from datetime import datetime, timedelta
from typing import List, Optional, Tuple

from onefuzztypes.enums import JobState, TaskState
from onefuzztypes.events import EventJobCreated, EventJobStopped
from onefuzztypes.models import Job as BASE_JOB

from .events import send_event
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .tasks.main import Task

@@ -37,13 +39,6 @@ class Job(BASE_JOB, ORMMixin):
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"task_info": ...}

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"state": ...,
"error": ...,
}

def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"machine_id": ...,
@@ -70,6 +65,11 @@ class Job(BASE_JOB, ORMMixin):
task.mark_stopping()
else:
self.state = JobState.stopped
send_event(
EventJobStopped(
job_id=self.job_id, config=self.config, user_info=self.user_info
)
)
self.save()

def on_start(self) -> None:
@@ -77,3 +77,14 @@ class Job(BASE_JOB, ORMMixin):
if self.end_time is None:
self.end_time = datetime.utcnow() + timedelta(hours=self.config.duration)
self.save()

def save(self, new: bool = False, require_etag: bool = False) -> None:
created = self.etag is None
super().save(new=new, require_etag=require_etag)

if created:
send_event(
EventJobCreated(
job_id=self.job_id, config=self.config, user_info=self.user_info
)
)
@@ -10,6 +10,7 @@ from uuid import UUID
from memoization import cached
from onefuzztypes import models
from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import EventCrashReported, EventFileAdded
from onefuzztypes.models import (
ADOTemplate,
Error,
@@ -27,7 +28,7 @@ from ..azure.containers import (
)
from ..azure.queue import send_message
from ..azure.storage import StorageType
from ..dashboard import add_event
from ..events import send_event
from ..orm import ORMMixin
from ..reports import get_report
from ..tasks.config import get_input_container_queues
@@ -116,16 +117,16 @@ def new_files(container: Container, filename: str) -> None:
if metadata:
results["metadata"] = metadata

report = get_report(container, filename)
if report:
results["executable"] = report.executable
results["crash_type"] = report.crash_type
results["crash_site"] = report.crash_site
results["job_id"] = report.job_id
results["task_id"] = report.task_id

notifications = get_notifications(container)
if notifications:
report = get_report(container, filename)
if report:
results["executable"] = report.executable
results["crash_type"] = report.crash_type
results["crash_site"] = report.crash_site
results["job_id"] = report.job_id
results["task_id"] = report.task_id

logging.info("notifications for %s %s %s", container, filename, notifications)
done = []
for notification in notifications:
@@ -154,4 +155,9 @@ def new_files(container: Container, filename: str) -> None:
)
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)

add_event("new_file", results)
if report:
send_event(
EventCrashReported(report=report, container=container, filename=filename)
)
else:
send_event(EventFileAdded(container=container, filename=filename))
@@ -39,7 +39,6 @@ from pydantic import BaseModel, Field
from typing_extensions import Protocol

from .azure.table import get_client
from .dashboard import add_event
from .telemetry import track_event_filtered
from .updates import queue_update

@@ -255,21 +254,9 @@ class ORMMixin(ModelMixin):
def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {}

def event_include(self) -> Optional[MappingIntStrAny]:
return {}

def telemetry(self) -> Any:
return self.raw(exclude_none=True, include=self.telemetry_include())

def _event_as_needed(self) -> None:
# Upon ORM save, if the object returns event data, we'll send it to the
# dashboard event subsystem

data = self.raw(exclude_none=True, include=self.event_include())
if not data:
return
add_event(self.table_name(), data)

def get_keys(self) -> Tuple[KEY, KEY]:
partition_key_field, row_key_field = self.key_fields()

@@ -331,13 +318,9 @@ class ORMMixin(ModelMixin):
if telem:
track_event_filtered(TelemetryEvent[self.table_name()], telem)

self._event_as_needed()
return None

def delete(self) -> None:
# fire off an event so Signalr knows it's being deleted
self._event_as_needed()

partition_key, row_key = self.get_keys()

client = get_client()
@@ -16,6 +16,16 @@ from onefuzztypes.enums import (
PoolState,
ScalesetState,
)
from onefuzztypes.events import (
EventNodeCreated,
EventNodeDeleted,
EventNodeStateUpdated,
EventPoolCreated,
EventPoolDeleted,
EventScalesetCreated,
EventScalesetDeleted,
EventScalesetFailed,
)
from onefuzztypes.models import AutoScaleConfig, Error
from onefuzztypes.models import Node as BASE_NODE
from onefuzztypes.models import NodeAssignment, NodeCommand, NodeCommandAddSshKey
@@ -60,6 +70,7 @@ from .azure.vmss import (
resize_vmss,
update_extensions,
)
from .events import send_event
from .extension import fuzz_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter

@@ -76,6 +87,31 @@ class Node(BASE_NODE, ORMMixin):
# should only be unset during agent_registration POST
reimage_queued: bool = Field(default=False)

@classmethod
def create(
cls,
*,
pool_name: PoolName,
machine_id: UUID,
scaleset_id: Optional[UUID],
version: str,
) -> "Node":
node = cls(
pool_name=pool_name,
machine_id=machine_id,
scaleset_id=scaleset_id,
version=version,
)
node.save()
send_event(
EventNodeCreated(
machine_id=node.machine_id,
scaleset_id=node.scaleset_id,
pool_name=node.pool_name,
)
)
return node

@classmethod
def search_states(
cls,
@@ -163,14 +199,6 @@ class Node(BASE_NODE, ORMMixin):
"scaleset_id": ...,
}

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_name": ...,
"machine_id": ...,
"state": ...,
"scaleset_id": ...,
}

def scaleset_node_exists(self) -> bool:
if self.scaleset_id is None:
return False
@@ -302,6 +330,7 @@ class Node(BASE_NODE, ORMMixin):
""" Tell the node to stop everything. """
self.set_shutdown()
self.stop()
self.set_state(NodeState.halt)

@classmethod
def get_dead_nodes(
@@ -315,8 +344,29 @@ class Node(BASE_NODE, ORMMixin):
raw_unchecked_filter=time_filter,
)

def set_state(self, state: NodeState) -> None:
if self.state != state:
self.state = state
send_event(
EventNodeStateUpdated(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
state=state,
)
)

self.save()

def delete(self) -> None:
NodeTasks.clear_by_machine_id(self.machine_id)
send_event(
EventNodeDeleted(
machine_id=self.machine_id,
pool_name=self.pool_name,
scaleset_id=self.scaleset_id,
)
)
super().delete()
NodeMessage.clear_messages(self.machine_id)
@@ -410,7 +460,7 @@ class Pool(BASE_POOL, ORMMixin):
client_id: Optional[UUID],
autoscale: Optional[AutoScaleConfig],
) -> "Pool":
return cls(
entry = cls(
name=name,
os=os,
arch=arch,
@@ -419,6 +469,17 @@ class Pool(BASE_POOL, ORMMixin):
config=None,
autoscale=autoscale,
)
entry.save()
send_event(
EventPoolCreated(
pool_name=name,
os=os,
arch=arch,
managed=managed,
autoscale=autoscale,
)
)
return entry

def save_exclude(self) -> Optional[MappingIntStrAny]:
return {
@@ -435,15 +496,6 @@ class Pool(BASE_POOL, ORMMixin):
"timestamp": ...,
}

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"name": ...,
"pool_id": ...,
"os": ...,
"state": ...,
"managed": ...,
}

def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_id": ...,
@@ -533,6 +585,17 @@ class Pool(BASE_POOL, ORMMixin):
query["state"] = states
return cls.search(query=query)

def set_shutdown(self, now: bool) -> None:
if self.state in [PoolState.halt, PoolState.shutdown]:
return

if now:
self.state = PoolState.halt
else:
self.state = PoolState.shutdown

self.save()

def shutdown(self) -> None:
""" shutdown allows nodes to finish current work then delete """
scalesets = Scaleset.search_by_pool(self.name)
@@ -545,8 +608,7 @@ class Pool(BASE_POOL, ORMMixin):
return

for scaleset in scalesets:
scaleset.state = ScalesetState.shutdown
scaleset.save()
scaleset.set_shutdown(now=False)

for node in nodes:
node.set_shutdown()
@@ -555,6 +617,7 @@ class Pool(BASE_POOL, ORMMixin):

def halt(self) -> None:
""" halt the pool immediately """

scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes:
@@ -577,21 +640,15 @@ class Pool(BASE_POOL, ORMMixin):
def key_fields(cls) -> Tuple[str, str]:
return ("name", "pool_id")

def delete(self) -> None:
super().delete()
send_event(EventPoolDeleted(pool_name=self.name))


class Scaleset(BASE_SCALESET, ORMMixin):
def save_exclude(self) -> Optional[MappingIntStrAny]:
return {"nodes": ...}

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"pool_name": ...,
"scaleset_id": ...,
"state": ...,
"os": ...,
"size": ...,
"error": ...,
}

def telemetry_include(self) -> Optional[MappingIntStrAny]:
return {
"scaleset_id": ...,
@@ -615,7 +672,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
client_id: Optional[UUID] = None,
client_object_id: Optional[UUID] = None,
) -> "Scaleset":
return cls(
entry = cls(
pool_name=pool_name,
vm_sku=vm_sku,
image=image,
@@ -627,6 +684,18 @@ class Scaleset(BASE_SCALESET, ORMMixin):
client_object_id=client_object_id,
tags=tags,
)
entry.save()
send_event(
EventScalesetCreated(
scaleset_id=entry.scaleset_id,
pool_name=entry.pool_name,
vm_sku=vm_sku,
image=image,
region=region,
size=size,
)
)
return entry

@classmethod
def search_by_pool(cls, pool_name: PoolName) -> List["Scaleset"]:
@@ -651,6 +720,20 @@ class Scaleset(BASE_SCALESET, ORMMixin):
def get_by_object_id(cls, object_id: UUID) -> List["Scaleset"]:
return cls.search(query={"client_object_id": [object_id]})

def set_failed(self, error: Error) -> None:
if self.error is not None:
return

self.error = error
self.state = ScalesetState.creation_failed
self.save()

send_event(
EventScalesetFailed(
scaleset_id=self.scaleset_id, pool_name=self.pool_name, error=self.error
)
)

def init(self) -> None:
logging.info("scaleset init: %s", self.scaleset_id)

@@ -660,9 +743,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
# scaleset being added to the pool.
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.state = ScalesetState.halt
self.save()
self.set_failed(pool)
return

if pool.state == PoolState.init:
@@ -672,14 +753,16 @@ class Scaleset(BASE_SCALESET, ORMMixin):
elif pool.state == PoolState.running:
image_os = get_os(self.region, self.image)
if isinstance(image_os, Error):
self.error = image_os
self.state = ScalesetState.creation_failed
self.set_failed(image_os)
return

elif image_os != pool.os:
self.error = Error(
error = Error(
code=ErrorCode.INVALID_REQUEST,
errors=["invalid os (got: %s needed: %s)" % (image_os, pool.os)],
)
self.state = ScalesetState.creation_failed
self.set_failed(error)
return
else:
self.state = ScalesetState.setup
else:
@@ -698,26 +781,23 @@ class Scaleset(BASE_SCALESET, ORMMixin):
logging.info("creating network: %s", self.region)
result = network.create()
if isinstance(result, Error):
self.error = result
self.state = ScalesetState.creation_failed
self.set_failed(result)
return
self.save()
return

if self.auth is None:
self.error = Error(
error = Error(
code=ErrorCode.UNABLE_TO_CREATE, errors=["missing required auth"]
)
self.state = ScalesetState.creation_failed
self.save()
self.set_failed(error)
return

vmss = get_vmss(self.scaleset_id)
if vmss is None:
pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.state = ScalesetState.halt
self.save()
self.set_failed(pool)
return

logging.info("creating scaleset: %s", self.scaleset_id)
@@ -736,13 +816,8 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.tags,
)
if isinstance(result, Error):
self.error = result
logging.error(
"stopping task because of failed vmss: %s %s",
self.scaleset_id,
result,
)
self.state = ScalesetState.creation_failed
self.set_failed(result)
return
else:
logging.info("creating scaleset: %s", self.scaleset_id)
elif vmss.provisioning_state == "Creating":
@@ -750,10 +825,10 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.try_set_identity(vmss)
else:
logging.info("scaleset running: %s", self.scaleset_id)
error = self.try_set_identity(vmss)
if error:
self.state = ScalesetState.creation_failed
self.error = error
identity_result = self.try_set_identity(vmss)
if identity_result:
self.set_failed(identity_result)
return
else:
self.state = ScalesetState.running
self.save()
@@ -843,8 +918,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.delete_nodes(to_delete)
for node in to_delete:
node.set_halt()
node.state = NodeState.halt
node.save()

if to_reimage:
self.reimage_nodes(to_reimage)
@@ -967,6 +1040,17 @@ class Scaleset(BASE_SCALESET, ORMMixin):
node.reimage_queued = True
node.save()

def set_shutdown(self, now: bool) -> None:
if self.state in [ScalesetState.halt, ScalesetState.shutdown]:
return

if now:
self.state = ScalesetState.halt
else:
self.state = ScalesetState.shutdown

self.save()

def shutdown(self) -> None:
size = get_vmss_size(self.scaleset_id)
logging.info("scaleset shutdown: %s (current size: %s)", self.scaleset_id, size)
@@ -977,7 +1061,6 @@ class Scaleset(BASE_SCALESET, ORMMixin):
self.halt()

def halt(self) -> None:
self.state = ScalesetState.halt
ScalesetShrinkQueue(self.scaleset_id).delete()

for node in Node.search_states(scaleset_id=self.scaleset_id):
@@ -1050,8 +1133,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):

pool = Pool.get_by_name(self.pool_name)
if isinstance(pool, Error):
self.error = pool
self.halt()
self.set_failed(pool)
return

logging.debug("updating scaleset configs: %s", self.scaleset_id)
@@ -1068,6 +1150,12 @@ class Scaleset(BASE_SCALESET, ORMMixin):
def key_fields(cls) -> Tuple[str, str]:
return ("pool_name", "scaleset_id")

def delete(self) -> None:
super().delete()
send_event(
EventScalesetDeleted(scaleset_id=self.scaleset_id, pool_name=self.pool_name)
)


class ShrinkEntry(BaseModel):
shrink_id: UUID = Field(default_factory=uuid4)
@@ -8,7 +8,8 @@ import logging
from typing import List, Optional, Tuple

from azure.mgmt.compute.models import VirtualMachine
from onefuzztypes.enums import VmState
from onefuzztypes.enums import ErrorCode, VmState
from onefuzztypes.events import EventProxyCreated, EventProxyDeleted, EventProxyFailed
from onefuzztypes.models import (
Authentication,
Error,
@@ -26,8 +27,9 @@ from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .azure.vm import VM
from .events import send_event
from .extension import proxy_manager_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter
from .orm import ORMMixin, QueryFilter
from .proxy_forward import ProxyForward

PROXY_SKU = "Standard_B2s"
@@ -41,7 +43,7 @@ class Proxy(ORMMixin):
state: VmState = Field(default=VmState.init)
auth: Authentication = Field(default_factory=build_auth)
ip: Optional[str]
error: Optional[str]
error: Optional[Error]
version: str = Field(default=__version__)
heartbeat: Optional[ProxyHeartbeat]

@@ -49,14 +51,6 @@ class Proxy(ORMMixin):
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("region", None)

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"region": ...,
"state": ...,
"ip": ...,
"error": ...,
}

def get_vm(self) -> VM:
vm = VM(
name="proxy-%s" % self.region,
@@ -72,42 +66,59 @@ class Proxy(ORMMixin):
vm_data = vm.get()
if vm_data:
if vm_data.provisioning_state == "Failed":
self.set_failed(vm)
self.set_provision_failed(vm_data)
return
else:
self.save_proxy_config()
self.state = VmState.extensions_launch
else:
result = vm.create()
if isinstance(result, Error):
self.error = repr(result)
self.state = VmState.stopping
self.set_failed(result)
return
self.save()

def set_failed(self, vm_data: VirtualMachine) -> None:
logging.error("vm failed to provision: %s", vm_data.name)
def set_provision_failed(self, vm_data: VirtualMachine) -> None:
errors = ["provisioning failed"]
for status in vm_data.instance_view.statuses:
if status.level.name.lower() == "error":
logging.error(
"vm status: %s %s %s %s",
vm_data.name,
status.code,
status.display_status,
status.message,
errors.append(
f"code:{status.code} status:{status.display_status} "
f"message:{status.message}"
)
self.state = VmState.vm_allocation_failed

self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=errors,
)
)
return

def set_failed(self, error: Error) -> None:
if self.error is not None:
return

logging.error("proxy vm failed: %s - %s", self.region, error)
send_event(EventProxyFailed(region=self.region, error=error))
self.error = error
self.state = VmState.stopping
self.save()

def extensions_launch(self) -> None:
vm = self.get_vm()
vm_data = vm.get()
if not vm_data:
logging.error("Azure VM does not exist: %s", vm.name)
self.state = VmState.stopping
self.save()
self.set_failed(
Error(
code=ErrorCode.PROXY_FAILED,
errors=["azure not able to find vm"],
)
)
return

if vm_data.provisioning_state == "Failed":
self.set_failed(vm_data)
self.save()
self.set_provision_failed(vm_data)
return

ip = get_public_ip(vm_data.network_profile.network_interfaces[0].id)
@@ -119,9 +130,8 @@ class Proxy(ORMMixin):
extensions = proxy_manager_extensions(self.region)
result = vm.add_extensions(extensions)
if isinstance(result, Error):
logging.error("vm extension failed: %s", repr(result))
self.error = repr(result)
self.state = VmState.stopping
self.set_failed(result)
return
elif result:
self.state = VmState.running

@@ -231,4 +241,9 @@ class Proxy(ORMMixin):

proxy = Proxy(region=region)
proxy.save()
send_event(EventProxyCreated(region=region))
return proxy

def delete(self) -> None:
super().delete()
send_event(EventProxyDeleted(region=self.region))
@@ -7,6 +7,7 @@ import json
import logging
from typing import Optional, Union

from memoization import cached
from onefuzztypes.models import Report
from onefuzztypes.primitives import Container
from pydantic import ValidationError
@@ -16,7 +17,7 @@ from .azure.storage import StorageType


def parse_report(
content: Union[str, bytes], metadata: Optional[str] = None
content: Union[str, bytes], file_path: Optional[str] = None
) -> Optional[Report]:
if isinstance(content, bytes):
try:
@@ -24,7 +25,7 @@ def parse_report(
except UnicodeDecodeError as err:
logging.error(
"unable to parse report (%s): unicode decode of report failed - %s",
metadata,
file_path,
err,
)
return None
@@ -33,28 +34,30 @@ def parse_report(
data = json.loads(content)
except json.decoder.JSONDecodeError as err:
logging.error(
"unable to parse report (%s): json decoding failed - %s", metadata, err
"unable to parse report (%s): json decoding failed - %s", file_path, err
)
return None

try:
entry = Report.parse_obj(data)
except ValidationError as err:
logging.error("unable to parse report (%s): %s", metadata, err)
logging.error("unable to parse report (%s): %s", file_path, err)
return None

return entry


# cache the last 1000 reports
@cached(max_size=1000)
def get_report(container: Container, filename: str) -> Optional[Report]:
metadata = "/".join([container, filename])
file_path = "/".join([container, filename])
if not filename.endswith(".json"):
logging.error("get_report invalid extension: %s", metadata)
logging.error("get_report invalid extension: %s", file_path)
return None

blob = get_blob(container, filename, StorageType.corpus)
if blob is None:
logging.error("get_report invalid blob: %s", metadata)
logging.error("get_report invalid blob: %s", file_path)
return None

return parse_report(blob, metadata=metadata)
return parse_report(blob, file_path=file_path)
@@ -9,22 +9,23 @@ from typing import List, Optional, Tuple, Union
from uuid import UUID

from onefuzztypes.enums import ErrorCode, TaskState
from onefuzztypes.events import (
EventTaskCreated,
EventTaskFailed,
EventTaskStateUpdated,
EventTaskStopped,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Task as BASE_TASK
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
from onefuzztypes.webhooks import (
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
)

from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType
from ..events import send_event
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward
from ..webhooks import Webhook


class Task(BASE_TASK, ORMMixin):
@@ -58,8 +59,8 @@ class Task(BASE_TASK, ORMMixin):
raise Exception("task must have vm or pool")
task = cls(config=config, job_id=job_id, os=os, user_info=user_info)
task.save()
Webhook.send_event(
WebhookEventTaskCreated(
send_event(
EventTaskCreated(
job_id=task.job_id,
task_id=task.task_id,
config=config,
@@ -116,18 +117,9 @@ class Task(BASE_TASK, ORMMixin):
"config": {"vm": {"count": ...}, "task": {"type": ...}},
}

def event_include(self) -> Optional[MappingIntStrAny]:
return {
"job_id": ...,
"task_id": ...,
"state": ...,
"error": ...,
}

def init(self) -> None:
create_queue(self.task_id, StorageType.corpus)
self.state = TaskState.waiting
self.save()
self.set_state(TaskState.waiting)

def stopping(self) -> None:
# TODO: we need to 'unschedule' this task from the existing pools
@@ -136,8 +128,7 @@ class Task(BASE_TASK, ORMMixin):
ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), StorageType.corpus)
Node.stop_task(self.task_id)
self.state = TaskState.stopped
self.save()
self.set_state(TaskState.stopped, send=False)

@classmethod
def search_states(
@@ -195,10 +186,9 @@ class Task(BASE_TASK, ORMMixin):
)
return

self.state = TaskState.stopping
self.save()
Webhook.send_event(
WebhookEventTaskStopped(
self.set_state(TaskState.stopping, send=False)
send_event(
EventTaskStopped(
job_id=self.job_id, task_id=self.task_id, user_info=self.user_info
)
)
@@ -211,11 +201,10 @@ class Task(BASE_TASK, ORMMixin):
return

self.error = error
self.state = TaskState.stopping
self.save()
self.set_state(TaskState.stopping, send=False)

Webhook.send_event(
WebhookEventTaskFailed(
send_event(
EventTaskFailed(
job_id=self.job_id,
task_id=self.task_id,
error=error,
@@ -287,7 +276,6 @@ class Task(BASE_TASK, ORMMixin):
self.end_time = datetime.utcnow() + timedelta(
hours=self.config.task.duration
)
self.save()

from ..jobs import Job

@@ -298,3 +286,22 @@ class Task(BASE_TASK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, str]:
return ("job_id", "task_id")

def set_state(self, state: TaskState, send: bool = True) -> None:
if self.state == state:
return

self.state = state
if self.state in [TaskState.running, TaskState.setting_up]:
self.on_start()

self.save()

send_event(
EventTaskStateUpdated(
job_id=self.job_id,
task_id=self.task_id,
state=self.state,
end_time=self.end_time,
)
)
@@ -235,8 +235,7 @@ def schedule_tasks() -> None:
if schedule_workset(work_set, bucket_config.pool, bucket_config.count):
for work_unit in work_set.work_units:
task = tasks_by_id[work_unit.task_id]
task.state = TaskState.scheduled
task.save()
task.set_state(TaskState.scheduled)
seen.add(task.task_id)

not_ready_count = len(tasks) - len(seen)
@@ -8,21 +8,15 @@ import hmac
import logging
from hashlib import sha512
from typing import List, Optional, Tuple
from uuid import UUID
from uuid import UUID, uuid4

import requests
from memoization import cached
from onefuzztypes.enums import ErrorCode, WebhookEventType, WebhookMessageState
from onefuzztypes.enums import ErrorCode, WebhookMessageState
from onefuzztypes.events import Event, EventMessage, EventPing, EventType
from onefuzztypes.models import Error, Result
from onefuzztypes.webhooks import Webhook as BASE_WEBHOOK
from onefuzztypes.webhooks import (
WebhookEvent,
WebhookEventPing,
WebhookEventTaskCreated,
WebhookEventTaskFailed,
WebhookEventTaskStopped,
WebhookMessage,
)
from onefuzztypes.webhooks import WebhookMessage
from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel

@@ -140,34 +134,18 @@ class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
)


def get_event_type(event: WebhookEvent) -> WebhookEventType:
events = {
WebhookEventTaskCreated: WebhookEventType.task_created,
WebhookEventTaskFailed: WebhookEventType.task_failed,
WebhookEventTaskStopped: WebhookEventType.task_stopped,
WebhookEventPing: WebhookEventType.ping,
}

for event_class in events:
if isinstance(event, event_class):
return events[event_class]

raise NotImplementedError("unsupported event type: %s" % event)


class Webhook(BASE_WEBHOOK, ORMMixin):
@classmethod
def key_fields(cls) -> Tuple[str, Optional[str]]:
return ("webhook_id", "name")

@classmethod
def send_event(cls, event: WebhookEvent) -> None:
event_type = get_event_type(event)
def send_event(cls, event_message: EventMessage) -> None:
for webhook in get_webhooks_cached():
if event_type not in webhook.event_types:
if event_message.event_type not in webhook.event_types:
continue

webhook._add_event(event_type, event)
webhook._add_event(event_message)

@classmethod
def get_by_id(cls, webhook_id: UUID) -> Result["Webhook"]:
@@ -185,18 +163,19 @@ class Webhook(BASE_WEBHOOK, ORMMixin):
webhook = webhooks[0]
return webhook

def _add_event(self, event_type: WebhookEventType, event: WebhookEvent) -> None:
def _add_event(self, event_message: EventMessage) -> None:
message = WebhookMessageLog(
webhook_id=self.webhook_id,
event_type=event_type,
event=event,
event_id=event_message.event_id,
event_type=event_message.event_type,
event=event_message.event,
)
message.save()
message.queue_webhook()

def ping(self) -> WebhookEventPing:
ping = WebhookEventPing()
self._add_event(WebhookEventType.ping, ping)
def ping(self) -> EventPing:
ping = EventPing(ping_id=uuid4())
self._add_event(EventMessage(event_type=EventType.ping, event=ping))
return ping

def send(self, message_log: WebhookMessageLog) -> bool:
@@ -228,8 +207,8 @@ def build_message(
*,
webhook_id: UUID,
event_id: UUID,
event_type: WebhookEventType,
event: WebhookEvent,
event_type: EventType,
event: Event,
secret_token: Optional[str] = None,
) -> Tuple[bytes, Optional[str]]:
data = (