add monkeypatch to hotfix pydantic Union issues (#982)

Until Pydantic supports discriminated or "smart" unions, we need to work around the coercion issue impacting unions in our models.

This reuses the "smart union" implementation from https://github.com/samuelcolvin/pydantic/pull/2092
This commit is contained in:
bmc-msft
2021-06-11 18:00:06 -04:00
committed by GitHub
parent a0bce1a538
commit c0b3a409e4
5 changed files with 68 additions and 0 deletions

View File

@ -0,0 +1,52 @@
# TODO: Remove once `smart_union` like support is added to Pydantic
#
# Written by @PrettyWood
# Code from https://github.com/samuelcolvin/pydantic/pull/2092
#
# Original project licensed under the MIT License.
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic.fields import ModelField
from pydantic.typing import get_origin
if TYPE_CHECKING:
from pydantic.fields import LocStr, ValidateReturn
from pydantic.types import ModelOrDc
upstream_validate_singleton = ModelField._validate_singleton
# this is a direct port of the functionality from the PR discussed above, though
# *all* unions are considered "smart" for our purposes.
def wrap_validate_singleton(
self: ModelField,
v: Any,
values: Dict[str, Any],
loc: "LocStr",
cls: Optional["ModelOrDc"],
) -> "ValidateReturn":
if self.sub_fields:
if get_origin(self.type_) is Union:
for field in self.sub_fields:
if v.__class__ is field.outer_type_:
return v, None
for field in self.sub_fields:
try:
if isinstance(v, field.outer_type_):
return v, None
except TypeError:
pass
return upstream_validate_singleton(self, v, values, loc, cls)
ModelField._validate_singleton = wrap_validate_singleton # type: ignore
# this should be included in any file that defines a pydantic model that uses a
# Union and calls to it should be removed when Pydantic's smart union support
# lands
def _check_hotfix() -> None:
if ModelField._validate_singleton != wrap_validate_singleton:
raise Exception("pydantic Union hotfix not applied")

View File

@ -10,6 +10,7 @@ from uuid import UUID, uuid4
from pydantic import BaseModel, Field
from ._monkeypatch import _check_hotfix
from .enums import (
OS,
Architecture,
@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage:
instance_id=instance_id,
instance_name=instance_name,
)
_check_hotfix()

View File

@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator
from ._monkeypatch import _check_hotfix
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
from .primitives import File
@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest):
class JobTemplateRequestParameters(BaseRequest):
user_fields: TemplateUserFields
_check_hotfix()

View File

@ -10,6 +10,7 @@ from uuid import UUID, uuid4
from pydantic import BaseModel, Field, root_validator, validator
from pydantic.dataclasses import dataclass
from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import (
OS,
@ -842,3 +843,6 @@ class Task(BaseModel):
events: Optional[List[TaskEventSummary]]
nodes: Optional[List[NodeAssignment]]
user_info: Optional[UserInfo]
_check_hotfix()

View File

@ -8,6 +8,7 @@ from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import (
OS,
@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel):
class NodeAddSshKey(BaseModel):
machine_id: UUID
public_key: str
_check_hotfix()