work around issue with discriminated typed unions (#939)

We're experiencing a bug where Unions of sub-models are getting downcast, which causes a loss of information.

As an example, EventScalesetCreated was getting downcast to EventScalesetDeleted. I have not figured out why, nor can I replicate it locally to produce a minimized bug report to send upstream, but I was able to reliably replicate it on the service.

While working through this issue, I noticed that deserialization of SignalR events was frequently wrong, leaving things like tasks as "init" in `status top`.

Both of these issues involve Unions of models discriminated by a type field, so they are likely related.
This commit is contained in:
bmc-msft
2021-06-02 12:40:58 -04:00
committed by GitHub
parent 3d191c3c5d
commit a92c84d42a
7 changed files with 81 additions and 39 deletions

View File

@ -24,7 +24,7 @@ def get_events() -> Optional[str]:
for _ in range(5): for _ in range(5):
try: try:
event = EVENTS.get(block=False) event = EVENTS.get(block=False)
events.append(json.loads(event.json(exclude_none=True))) events.append(json.loads(event))
EVENTS.task_done() EVENTS.task_done()
except Empty: except Empty:
break break
@ -36,13 +36,13 @@ def get_events() -> Optional[str]:
def log_event(event: Event, event_type: EventType) -> None: def log_event(event: Event, event_type: EventType) -> None:
scrubbed_event = filter_event(event, event_type) scrubbed_event = filter_event(event)
logging.info( logging.info(
"sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True) "sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
) )
def filter_event(event: Event, event_type: EventType) -> BaseModel: def filter_event(event: Event) -> BaseModel:
clone_event = event.copy(deep=True) clone_event = event.copy(deep=True)
filtered_event = filter_event_recurse(clone_event) filtered_event = filter_event_recurse(clone_event)
return filtered_event return filtered_event
@ -73,12 +73,18 @@ def filter_event_recurse(entry: BaseModel) -> BaseModel:
def send_event(event: Event) -> None: def send_event(event: Event) -> None:
event_type = get_event_type(event) event_type = get_event_type(event)
log_event(event, event_type)
event_message = EventMessage( event_message = EventMessage(
event_type=event_type, event_type=event_type,
event=event, event=event.copy(deep=True),
instance_id=get_instance_id(), instance_id=get_instance_id(),
instance_name=get_instance_name(), instance_name=get_instance_name(),
) )
EVENTS.put(event_message)
# work around odd bug with Event Message creation. See PR 939
if event_message.event != event:
event_message.event = event.copy(deep=True)
EVENTS.put(event_message.json())
Webhook.send_event(event_message) Webhook.send_event(event_message)
log_event(event, event_type)

View File

@ -23,7 +23,7 @@ azure-storage-queue==12.1.6
jinja2~=2.11.3 jinja2~=2.11.3
msrestazure~=0.6.3 msrestazure~=0.6.3
opencensus-ext-azure~=1.0.2 opencensus-ext-azure~=1.0.2
pydantic~=1.8.1 --no-binary=pydantic pydantic==1.8.2 --no-binary=pydantic
PyJWT~=1.7.1 PyJWT~=1.7.1
requests~=2.25.1 requests~=2.25.1
memoization~=0.3.1 memoization~=0.3.1

View File

@ -8,7 +8,7 @@ import unittest
from uuid import uuid4 from uuid import uuid4
from onefuzztypes.enums import ContainerType, TaskType from onefuzztypes.enums import ContainerType, TaskType
from onefuzztypes.events import EventTaskCreated, get_event_type from onefuzztypes.events import EventTaskCreated
from onefuzztypes.models import ( from onefuzztypes.models import (
TaskConfig, TaskConfig,
TaskContainers, TaskContainers,
@ -65,9 +65,7 @@ class TestUserInfoFilter(unittest.TestCase):
user_info=None, user_info=None,
) )
test_event_type = get_event_type(test_event) scrubbed_test_event = filter_event(test_event)
scrubbed_test_event = filter_event(test_event, test_event_type)
self.assertEqual(scrubbed_test_event, control_test_event) self.assertEqual(scrubbed_test_event, control_test_event)

View File

@ -3,6 +3,7 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
import json
import logging import logging
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@ -15,7 +16,6 @@ from onefuzztypes.events import (
EventFileAdded, EventFileAdded,
EventJobCreated, EventJobCreated,
EventJobStopped, EventJobStopped,
EventMessage,
EventNodeCreated, EventNodeCreated,
EventNodeDeleted, EventNodeDeleted,
EventNodeStateUpdated, EventNodeStateUpdated,
@ -26,6 +26,7 @@ from onefuzztypes.events import (
EventTaskStateUpdated, EventTaskStateUpdated,
EventTaskStopped, EventTaskStopped,
EventType, EventType,
parse_event_message,
) )
from onefuzztypes.models import ( from onefuzztypes.models import (
Job, Job,
@ -152,31 +153,43 @@ class TopCache:
self.add_files_set(name, set(files.files)) self.add_files_set(name, set(files.files))
def add_message(self, message: EventMessage) -> None: def add_message(self, message_obj: Any) -> None:
events = { message = parse_event_message(message_obj)
EventPoolCreated: lambda x: self.pool_created(x),
EventPoolDeleted: lambda x: self.pool_deleted(x),
EventTaskCreated: lambda x: self.task_created(x),
EventTaskStopped: lambda x: self.task_stopped(x),
EventTaskFailed: lambda x: self.task_stopped(x),
EventTaskStateUpdated: lambda x: self.task_state_updated(x),
EventJobCreated: lambda x: self.job_created(x),
EventJobStopped: lambda x: self.job_stopped(x),
EventNodeStateUpdated: lambda x: self.node_state_updated(x),
EventNodeCreated: lambda x: self.node_created(x),
EventNodeDeleted: lambda x: self.node_deleted(x),
EventCrashReported: lambda x: self.file_added(x),
EventFileAdded: lambda x: self.file_added(x),
}
for event_cls in events: event = message.event
if isinstance(message.event, event_cls): if isinstance(event, EventPoolCreated):
events[event_cls](message.event) self.pool_created(event)
elif isinstance(event, EventPoolDeleted):
self.pool_deleted(event)
elif isinstance(event, EventTaskCreated):
self.task_created(event)
elif isinstance(event, EventTaskStopped):
self.task_stopped(event)
elif isinstance(event, EventTaskFailed):
self.task_failed(event)
elif isinstance(event, EventTaskStateUpdated):
self.task_state_updated(event)
elif isinstance(event, EventJobCreated):
self.job_created(event)
elif isinstance(event, EventJobStopped):
self.job_stopped(event)
elif isinstance(event, EventNodeStateUpdated):
self.node_state_updated(event)
elif isinstance(event, EventNodeCreated):
self.node_created(event)
elif isinstance(event, EventNodeDeleted):
self.node_deleted(event)
elif isinstance(event, (EventCrashReported, EventFileAdded)):
self.file_added(event)
self.last_update = datetime.now() self.last_update = datetime.now()
messages = [x for x in self.messages][-99:] messages = [x for x in self.messages][-99:]
messages += [ messages += [
(datetime.now(), message.event_type, message.event.json(exclude_none=True)) (
datetime.now(),
message.event_type,
json.dumps(message_obj, sort_keys=True),
)
] ]
self.messages = messages self.messages = messages
@ -301,6 +314,10 @@ class TopCache:
if event.task_id in self.tasks: if event.task_id in self.tasks:
del self.tasks[event.task_id] del self.tasks[event.task_id]
def task_failed(self, event: EventTaskFailed) -> None:
if event.task_id in self.tasks:
del self.tasks[event.task_id]
def render_tasks(self) -> List: def render_tasks(self) -> List:
results = [] results = []
for task in self.tasks.values(): for task in self.tasks.values():
@ -352,6 +369,11 @@ class TopCache:
if event.job_id in self.jobs: if event.job_id in self.jobs:
del self.jobs[event.job_id] del self.jobs[event.job_id]
to_remove = [x.task_id for x in self.tasks.values() if x.job_id == event.job_id]
for task_id in to_remove:
del self.tasks[task_id]
def render_jobs(self) -> List[Tuple]: def render_jobs(self) -> List[Tuple]:
results: List[Tuple] = [] results: List[Tuple] = []

View File

@ -9,8 +9,6 @@ from queue import PriorityQueue
from threading import Thread from threading import Thread
from typing import Any, Optional from typing import Any, Optional
from onefuzztypes.events import EventMessage
from .cache import JobFilter, TopCache from .cache import JobFilter, TopCache
from .signalr import Stream from .signalr import Stream
from .top_view import render from .top_view import render
@ -50,8 +48,7 @@ class Top:
def handler(self, message: Any) -> None: def handler(self, message: Any) -> None:
for event_raw in message: for event_raw in message:
message = EventMessage.parse_obj(event_raw) self.cache.add_message(event_raw)
self.cache.add_message(message)
def setup(self) -> Stream: def setup(self) -> Stream:
client = Stream(self.onefuzz, self.logger) client = Stream(self.onefuzz, self.logger)

View File

@ -5,7 +5,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Optional, Union from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -91,7 +91,7 @@ class EventTaskHeartbeat(BaseEvent):
config: TaskConfig config: TaskConfig
class EventPing(BaseResponse): class EventPing(BaseEvent, BaseResponse):
ping_id: UUID ping_id: UUID
@ -300,3 +300,22 @@ class EventMessage(BaseEvent):
event: Event event: Event
instance_id: UUID instance_id: UUID
instance_name: str instance_name: str
# because Pydantic does not yet have discriminated union types yet, parse events
# by hand. https://github.com/samuelcolvin/pydantic/issues/619
def parse_event_message(data: Dict[str, Any]) -> EventMessage:
instance_id = UUID(data["instance_id"])
instance_name = data["instance_name"]
event_id = UUID(data["event_id"])
event_type = EventType[data["event_type"]]
# mypy incorrectly identifies this as having not supported parse_obj yet
event = EventTypeMap[event_type].parse_obj(data["event"]) # type: ignore
return EventMessage(
event_id=event_id,
event_type=event_type,
event=event,
instance_id=instance_id,
instance_name=instance_name,
)

View File

@ -1 +1 @@
pydantic~=1.8.1 --no-binary=pydantic pydantic==1.8.2 --no-binary=pydantic