work around issue with discriminated typed unions (#939)

We're experiencing a bug where Unions of sub-models are getting downcast, which causes a loss of information.

As an example, EventScalesetCreated was getting downcast to EventScalesetDeleted. I have not figured out why, nor can I replicate it locally well enough to minimize the bug and send a report upstream, but I was able to reliably replicate it on the service.

While working through this issue, I noticed that deserialization of SignalR events was frequently wrong, leaving things like tasks stuck at "init" in `status top`.

Both of these issues involve Unions of models with a type field, so they are likely related.
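
For illustration only (this is not code from the change below), here is a minimal sketch of how an untagged Union can downcast in pydantic 1.x: when a union member whose fields are a subset is tried first, a richer event also validates against it, and the extra fields are silently dropped. The models are simplified stand-ins for the real event types.

```python
from typing import Union
from uuid import UUID, uuid4

from pydantic import BaseModel  # pydantic 1.x


# simplified stand-ins for the real onefuzz event models
class EventScalesetDeleted(BaseModel):
    scaleset_id: UUID
    pool_name: str


class EventScalesetCreated(BaseModel):
    scaleset_id: UUID
    pool_name: str
    size: int


# an untagged union; member order matters to pydantic 1.x
Event = Union[EventScalesetDeleted, EventScalesetCreated]


class EventMessage(BaseModel):
    event: Event


msg = EventMessage(
    event=EventScalesetCreated(scaleset_id=uuid4(), pool_name="demo", size=3)
)
# the Created event also validates as a Deleted event (extra fields are
# ignored by default), so the union picks the first match and `size` is lost
print(type(msg.event).__name__)  # EventScalesetDeleted
```
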
Author: bmc-msft
Date: 2021-06-02 12:40:58 -04:00 (committed via GitHub)
Parent: 3d191c3c5d
Commit: a92c84d42a
7 changed files with 81 additions and 39 deletions

========== changed file ==========

@@ -24,7 +24,7 @@ def get_events() -> Optional[str]:
     for _ in range(5):
         try:
             event = EVENTS.get(block=False)
-            events.append(json.loads(event.json(exclude_none=True)))
+            events.append(json.loads(event))
            EVENTS.task_done()
         except Empty:
             break
@@ -36,13 +36,13 @@ def get_events() -> Optional[str]:
 def log_event(event: Event, event_type: EventType) -> None:
-    scrubbed_event = filter_event(event, event_type)
+    scrubbed_event = filter_event(event)
     logging.info(
         "sending event: %s - %s", event_type, scrubbed_event.json(exclude_none=True)
     )


-def filter_event(event: Event, event_type: EventType) -> BaseModel:
+def filter_event(event: Event) -> BaseModel:
     clone_event = event.copy(deep=True)
     filtered_event = filter_event_recurse(clone_event)
     return filtered_event
@@ -73,12 +73,18 @@ def filter_event_recurse(entry: BaseModel) -> BaseModel:
 def send_event(event: Event) -> None:
     event_type = get_event_type(event)
-    log_event(event, event_type)
     event_message = EventMessage(
         event_type=event_type,
-        event=event,
+        event=event.copy(deep=True),
         instance_id=get_instance_id(),
         instance_name=get_instance_name(),
     )
-    EVENTS.put(event_message)
+
+    # work around odd bug with Event Message creation. See PR 939
+    if event_message.event != event:
+        event_message.event = event.copy(deep=True)
+
+    EVENTS.put(event_message.json())
     Webhook.send_event(event_message)
+    log_event(event, event_type)
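
The guard added above checks whether constructing EventMessage downcast the event (the bug described in the commit message); once fields are lost, the stored event no longer compares equal to the original, and the original is restored by direct assignment. A standalone sketch of the same check, reusing the simplified stand-in models from the example near the top (illustrative, not service code); it assumes, as in pydantic 1.x by default, that `validate_assignment` is off:

```python
# reusing the simplified stand-in models from the earlier sketch
original = EventScalesetCreated(scaleset_id=uuid4(), pool_name="demo", size=3)
msg = EventMessage(event=original)

# if the constructor downcast the event, the lost field makes the
# comparison fail, and we restore the original by direct assignment
if msg.event != original:
    # assignment is not validated here (pydantic 1.x default:
    # validate_assignment is off), so no second coercion happens
    msg.event = original.copy(deep=True)

assert isinstance(msg.event, EventScalesetCreated)
```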

========== changed file ==========

@@ -23,7 +23,7 @@ azure-storage-queue==12.1.6
 jinja2~=2.11.3
 msrestazure~=0.6.3
 opencensus-ext-azure~=1.0.2
-pydantic~=1.8.1 --no-binary=pydantic
+pydantic==1.8.2 --no-binary=pydantic
 PyJWT~=1.7.1
 requests~=2.25.1
 memoization~=0.3.1

========== changed file ==========

@@ -8,7 +8,7 @@ import unittest
 from uuid import uuid4

 from onefuzztypes.enums import ContainerType, TaskType
-from onefuzztypes.events import EventTaskCreated, get_event_type
+from onefuzztypes.events import EventTaskCreated
 from onefuzztypes.models import (
     TaskConfig,
     TaskContainers,
@@ -65,9 +65,7 @@ class TestUserInfoFilter(unittest.TestCase):
             user_info=None,
         )

-        test_event_type = get_event_type(test_event)
-
-        scrubbed_test_event = filter_event(test_event, test_event_type)
+        scrubbed_test_event = filter_event(test_event)
         self.assertEqual(scrubbed_test_event, control_test_event)

========== changed file ==========

@@ -3,6 +3,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+import json
 import logging
 from datetime import datetime
 from enum import Enum
@@ -15,7 +16,6 @@ from onefuzztypes.events import (
     EventFileAdded,
     EventJobCreated,
     EventJobStopped,
-    EventMessage,
     EventNodeCreated,
     EventNodeDeleted,
     EventNodeStateUpdated,
@@ -26,6 +26,7 @@ from onefuzztypes.events import (
     EventTaskStateUpdated,
     EventTaskStopped,
     EventType,
+    parse_event_message,
 )
 from onefuzztypes.models import (
     Job,
@@ -152,31 +153,43 @@ class TopCache:
         self.add_files_set(name, set(files.files))

-    def add_message(self, message: EventMessage) -> None:
-        events = {
-            EventPoolCreated: lambda x: self.pool_created(x),
-            EventPoolDeleted: lambda x: self.pool_deleted(x),
-            EventTaskCreated: lambda x: self.task_created(x),
-            EventTaskStopped: lambda x: self.task_stopped(x),
-            EventTaskFailed: lambda x: self.task_stopped(x),
-            EventTaskStateUpdated: lambda x: self.task_state_updated(x),
-            EventJobCreated: lambda x: self.job_created(x),
-            EventJobStopped: lambda x: self.job_stopped(x),
-            EventNodeStateUpdated: lambda x: self.node_state_updated(x),
-            EventNodeCreated: lambda x: self.node_created(x),
-            EventNodeDeleted: lambda x: self.node_deleted(x),
-            EventCrashReported: lambda x: self.file_added(x),
-            EventFileAdded: lambda x: self.file_added(x),
-        }
+    def add_message(self, message_obj: Any) -> None:
+        message = parse_event_message(message_obj)

-        for event_cls in events:
-            if isinstance(message.event, event_cls):
-                events[event_cls](message.event)
+        event = message.event
+        if isinstance(event, EventPoolCreated):
+            self.pool_created(event)
+        elif isinstance(event, EventPoolDeleted):
+            self.pool_deleted(event)
+        elif isinstance(event, EventTaskCreated):
+            self.task_created(event)
+        elif isinstance(event, EventTaskStopped):
+            self.task_stopped(event)
+        elif isinstance(event, EventTaskFailed):
+            self.task_failed(event)
+        elif isinstance(event, EventTaskStateUpdated):
+            self.task_state_updated(event)
+        elif isinstance(event, EventJobCreated):
+            self.job_created(event)
+        elif isinstance(event, EventJobStopped):
+            self.job_stopped(event)
+        elif isinstance(event, EventNodeStateUpdated):
+            self.node_state_updated(event)
+        elif isinstance(event, EventNodeCreated):
+            self.node_created(event)
+        elif isinstance(event, EventNodeDeleted):
+            self.node_deleted(event)
+        elif isinstance(event, (EventCrashReported, EventFileAdded)):
+            self.file_added(event)

         self.last_update = datetime.now()
         messages = [x for x in self.messages][-99:]
         messages += [
-            (datetime.now(), message.event_type, message.event.json(exclude_none=True))
+            (
+                datetime.now(),
+                message.event_type,
+                json.dumps(message_obj, sort_keys=True),
+            )
         ]
         self.messages = messages
@@ -301,6 +314,10 @@ class TopCache:
         if event.task_id in self.tasks:
             del self.tasks[event.task_id]

+    def task_failed(self, event: EventTaskFailed) -> None:
+        if event.task_id in self.tasks:
+            del self.tasks[event.task_id]
+
     def render_tasks(self) -> List:
         results = []
         for task in self.tasks.values():
@@ -352,6 +369,11 @@ class TopCache:
         if event.job_id in self.jobs:
             del self.jobs[event.job_id]

+        to_remove = [x.task_id for x in self.tasks.values() if x.job_id == event.job_id]
+        for task_id in to_remove:
+            del self.tasks[task_id]
+
     def render_jobs(self) -> List[Tuple]:
         results: List[Tuple] = []

========== changed file ==========

@@ -9,8 +9,6 @@ from queue import PriorityQueue
 from threading import Thread
 from typing import Any, Optional

-from onefuzztypes.events import EventMessage
-
 from .cache import JobFilter, TopCache
 from .signalr import Stream
 from .top_view import render
@@ -50,8 +48,7 @@ class Top:
     def handler(self, message: Any) -> None:
         for event_raw in message:
-            message = EventMessage.parse_obj(event_raw)
-            self.cache.add_message(message)
+            self.cache.add_message(event_raw)

     def setup(self) -> Stream:
         client = Stream(self.onefuzz, self.logger)

========== changed file ==========

@@ -5,7 +5,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from uuid import UUID, uuid4

 from pydantic import BaseModel, Field
@@ -91,7 +91,7 @@ class EventTaskHeartbeat(BaseEvent):
     config: TaskConfig


-class EventPing(BaseResponse):
+class EventPing(BaseEvent, BaseResponse):
     ping_id: UUID
@@ -300,3 +300,22 @@ class EventMessage(BaseEvent):
     event: Event
     instance_id: UUID
     instance_name: str
+
+
+# because Pydantic does not yet have discriminated union types, parse events
+# by hand. https://github.com/samuelcolvin/pydantic/issues/619
+def parse_event_message(data: Dict[str, Any]) -> EventMessage:
+    instance_id = UUID(data["instance_id"])
+    instance_name = data["instance_name"]
+    event_id = UUID(data["event_id"])
+    event_type = EventType[data["event_type"]]
+    # mypy incorrectly believes this type does not support parse_obj
+    event = EventTypeMap[event_type].parse_obj(data["event"])  # type: ignore
+    return EventMessage(
+        event_id=event_id,
+        event_type=event_type,
+        event=event,
+        instance_id=instance_id,
+        instance_name=instance_name,
+    )
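
A quick usage sketch of the new parser. The payload here is hand-built and illustrative; it assumes EventType has a `ping` member that EventTypeMap maps to the EventPing model amended above:

```python
import json
from uuid import uuid4

from onefuzztypes.events import EventPing, parse_event_message

# an illustrative serialized EventMessage, e.g. as read off a queue or
# the SignalR stream after this change
raw = json.loads(
    json.dumps(
        {
            "event_id": str(uuid4()),
            "event_type": "ping",
            "event": {"ping_id": str(uuid4())},
            "instance_id": str(uuid4()),
            "instance_name": "example",
        }
    )
)

message = parse_event_message(raw)
# the event is parsed with the exact class for its type, so no sibling
# model in the Event union can capture and downcast it
assert isinstance(message.event, EventPing)
```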

========== changed file ==========

@@ -1 +1 @@
-pydantic~=1.8.1 --no-binary=pydantic
+pydantic==1.8.2 --no-binary=pydantic