mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-18 20:58:06 +00:00
add monkeypatch to hotfix pydantic Union issues (#982)
Until Pydantic supports discriminated or "smart" unions, we need to work around the coercion issue impacting unions in our models. This reuses the "smart union" implementation from https://github.com/samuelcolvin/pydantic/pull/2092
This commit is contained in:
52
src/pytypes/onefuzztypes/_monkeypatch.py
Normal file
52
src/pytypes/onefuzztypes/_monkeypatch.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# TODO: Remove once `smart_union` like support is added to Pydantic
|
||||||
|
#
|
||||||
|
# Written by @PrettyWood
|
||||||
|
# Code from https://github.com/samuelcolvin/pydantic/pull/2092
|
||||||
|
#
|
||||||
|
# Original project licensed under the MIT License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
from pydantic.fields import ModelField
|
||||||
|
from pydantic.typing import get_origin
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic.fields import LocStr, ValidateReturn
|
||||||
|
from pydantic.types import ModelOrDc
|
||||||
|
|
||||||
|
upstream_validate_singleton = ModelField._validate_singleton
|
||||||
|
|
||||||
|
|
||||||
|
# this is a direct port of the functionality from the PR discussed above, though
|
||||||
|
# *all* unions are considered "smart" for our purposes.
|
||||||
|
def wrap_validate_singleton(
|
||||||
|
self: ModelField,
|
||||||
|
v: Any,
|
||||||
|
values: Dict[str, Any],
|
||||||
|
loc: "LocStr",
|
||||||
|
cls: Optional["ModelOrDc"],
|
||||||
|
) -> "ValidateReturn":
|
||||||
|
if self.sub_fields:
|
||||||
|
if get_origin(self.type_) is Union:
|
||||||
|
for field in self.sub_fields:
|
||||||
|
if v.__class__ is field.outer_type_:
|
||||||
|
return v, None
|
||||||
|
for field in self.sub_fields:
|
||||||
|
try:
|
||||||
|
if isinstance(v, field.outer_type_):
|
||||||
|
return v, None
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return upstream_validate_singleton(self, v, values, loc, cls)
|
||||||
|
|
||||||
|
|
||||||
|
ModelField._validate_singleton = wrap_validate_singleton # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# this should be included in any file that defines a pydantic model that uses a
|
||||||
|
# Union and calls to it should be removed when Pydantic's smart union support
|
||||||
|
# lands
|
||||||
|
def _check_hotfix() -> None:
|
||||||
|
if ModelField._validate_singleton != wrap_validate_singleton:
|
||||||
|
raise Exception("pydantic Union hotfix not applied")
|
@ -10,6 +10,7 @@ from uuid import UUID, uuid4
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ._monkeypatch import _check_hotfix
|
||||||
from .enums import (
|
from .enums import (
|
||||||
OS,
|
OS,
|
||||||
Architecture,
|
Architecture,
|
||||||
@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage:
|
|||||||
instance_id=instance_id,
|
instance_id=instance_id,
|
||||||
instance_name=instance_name,
|
instance_name=instance_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_check_hotfix()
|
||||||
|
@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator, validator
|
from pydantic import BaseModel, Field, root_validator, validator
|
||||||
|
|
||||||
|
from ._monkeypatch import _check_hotfix
|
||||||
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
|
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
|
||||||
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
|
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
|
||||||
from .primitives import File
|
from .primitives import File
|
||||||
@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest):
|
|||||||
|
|
||||||
class JobTemplateRequestParameters(BaseRequest):
|
class JobTemplateRequestParameters(BaseRequest):
|
||||||
user_fields: TemplateUserFields
|
user_fields: TemplateUserFields
|
||||||
|
|
||||||
|
|
||||||
|
_check_hotfix()
|
||||||
|
@ -10,6 +10,7 @@ from uuid import UUID, uuid4
|
|||||||
from pydantic import BaseModel, Field, root_validator, validator
|
from pydantic import BaseModel, Field, root_validator, validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from ._monkeypatch import _check_hotfix
|
||||||
from .consts import ONE_HOUR, SEVEN_DAYS
|
from .consts import ONE_HOUR, SEVEN_DAYS
|
||||||
from .enums import (
|
from .enums import (
|
||||||
OS,
|
OS,
|
||||||
@ -842,3 +843,6 @@ class Task(BaseModel):
|
|||||||
events: Optional[List[TaskEventSummary]]
|
events: Optional[List[TaskEventSummary]]
|
||||||
nodes: Optional[List[NodeAssignment]]
|
nodes: Optional[List[NodeAssignment]]
|
||||||
user_info: Optional[UserInfo]
|
user_info: Optional[UserInfo]
|
||||||
|
|
||||||
|
|
||||||
|
_check_hotfix()
|
||||||
|
@ -8,6 +8,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
|
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
|
||||||
|
|
||||||
|
from ._monkeypatch import _check_hotfix
|
||||||
from .consts import ONE_HOUR, SEVEN_DAYS
|
from .consts import ONE_HOUR, SEVEN_DAYS
|
||||||
from .enums import (
|
from .enums import (
|
||||||
OS,
|
OS,
|
||||||
@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel):
|
|||||||
class NodeAddSshKey(BaseModel):
|
class NodeAddSshKey(BaseModel):
|
||||||
machine_id: UUID
|
machine_id: UUID
|
||||||
public_key: str
|
public_key: str
|
||||||
|
|
||||||
|
|
||||||
|
_check_hotfix()
|
||||||
|
Reference in New Issue
Block a user