mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-19 04:58:09 +00:00
add monkeypatch to hotfix pydantic Union issues (#982)
Until Pydantic supports discriminated or "smart" unions, we need to work around the coercion issue impacting unions in our models. This reuses the "smart union" implementation from https://github.com/samuelcolvin/pydantic/pull/2092
This commit is contained in:
52
src/pytypes/onefuzztypes/_monkeypatch.py
Normal file
52
src/pytypes/onefuzztypes/_monkeypatch.py
Normal file
@ -0,0 +1,52 @@
|
||||
# TODO: Remove once `smart_union` like support is added to Pydantic
|
||||
#
|
||||
# Written by @PrettyWood
|
||||
# Code from https://github.com/samuelcolvin/pydantic/pull/2092
|
||||
#
|
||||
# Original project licensed under the MIT License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from pydantic.fields import ModelField
|
||||
from pydantic.typing import get_origin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.fields import LocStr, ValidateReturn
|
||||
from pydantic.types import ModelOrDc
|
||||
|
||||
upstream_validate_singleton = ModelField._validate_singleton
|
||||
|
||||
|
||||
# this is a direct port of the functionality from the PR discussed above, though
|
||||
# *all* unions are considered "smart" for our purposes.
|
||||
def wrap_validate_singleton(
|
||||
self: ModelField,
|
||||
v: Any,
|
||||
values: Dict[str, Any],
|
||||
loc: "LocStr",
|
||||
cls: Optional["ModelOrDc"],
|
||||
) -> "ValidateReturn":
|
||||
if self.sub_fields:
|
||||
if get_origin(self.type_) is Union:
|
||||
for field in self.sub_fields:
|
||||
if v.__class__ is field.outer_type_:
|
||||
return v, None
|
||||
for field in self.sub_fields:
|
||||
try:
|
||||
if isinstance(v, field.outer_type_):
|
||||
return v, None
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
return upstream_validate_singleton(self, v, values, loc, cls)
|
||||
|
||||
|
||||
ModelField._validate_singleton = wrap_validate_singleton # type: ignore
|
||||
|
||||
|
||||
# this should be included in any file that defines a pydantic model that uses a
|
||||
# Union and calls to it should be removed when Pydantic's smart union support
|
||||
# lands
|
||||
def _check_hotfix() -> None:
|
||||
if ModelField._validate_singleton != wrap_validate_singleton:
|
||||
raise Exception("pydantic Union hotfix not applied")
|
@ -10,6 +10,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ._monkeypatch import _check_hotfix
|
||||
from .enums import (
|
||||
OS,
|
||||
Architecture,
|
||||
@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage:
|
||||
instance_id=instance_id,
|
||||
instance_name=instance_name,
|
||||
)
|
||||
|
||||
|
||||
_check_hotfix()
|
||||
|
@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
|
||||
from ._monkeypatch import _check_hotfix
|
||||
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
|
||||
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
|
||||
from .primitives import File
|
||||
@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest):
|
||||
|
||||
class JobTemplateRequestParameters(BaseRequest):
|
||||
user_fields: TemplateUserFields
|
||||
|
||||
|
||||
_check_hotfix()
|
||||
|
@ -10,6 +10,7 @@ from uuid import UUID, uuid4
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from ._monkeypatch import _check_hotfix
|
||||
from .consts import ONE_HOUR, SEVEN_DAYS
|
||||
from .enums import (
|
||||
OS,
|
||||
@ -842,3 +843,6 @@ class Task(BaseModel):
|
||||
events: Optional[List[TaskEventSummary]]
|
||||
nodes: Optional[List[NodeAssignment]]
|
||||
user_info: Optional[UserInfo]
|
||||
|
||||
|
||||
_check_hotfix()
|
||||
|
@ -8,6 +8,7 @@ from uuid import UUID
|
||||
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
|
||||
|
||||
from ._monkeypatch import _check_hotfix
|
||||
from .consts import ONE_HOUR, SEVEN_DAYS
|
||||
from .enums import (
|
||||
OS,
|
||||
@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel):
|
||||
class NodeAddSshKey(BaseModel):
|
||||
machine_id: UUID
|
||||
public_key: str
|
||||
|
||||
|
||||
_check_hotfix()
|
||||
|
Reference in New Issue
Block a user