add monkeypatch to hotfix pydantic Union issues (#982)

Until Pydantic supports discriminated or "smart" unions, we need to work around the coercion issue impacting unions in our models.

This reuses the "smart union" implementation from https://github.com/samuelcolvin/pydantic/pull/2092
This commit is contained in:
bmc-msft
2021-06-11 18:00:06 -04:00
committed by GitHub
parent a0bce1a538
commit c0b3a409e4
5 changed files with 68 additions and 0 deletions

View File

@ -0,0 +1,52 @@
# TODO: Remove once `smart_union` like support is added to Pydantic
#
# Written by @PrettyWood
# Code from https://github.com/samuelcolvin/pydantic/pull/2092
#
# Original project licensed under the MIT License.
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic.fields import ModelField
from pydantic.typing import get_origin
if TYPE_CHECKING:
from pydantic.fields import LocStr, ValidateReturn
from pydantic.types import ModelOrDc
upstream_validate_singleton = ModelField._validate_singleton
# this is a direct port of the functionality from the PR discussed above, though
# *all* unions are considered "smart" for our purposes.
def wrap_validate_singleton(
self: ModelField,
v: Any,
values: Dict[str, Any],
loc: "LocStr",
cls: Optional["ModelOrDc"],
) -> "ValidateReturn":
if self.sub_fields:
if get_origin(self.type_) is Union:
for field in self.sub_fields:
if v.__class__ is field.outer_type_:
return v, None
for field in self.sub_fields:
try:
if isinstance(v, field.outer_type_):
return v, None
except TypeError:
pass
return upstream_validate_singleton(self, v, values, loc, cls)
ModelField._validate_singleton = wrap_validate_singleton # type: ignore
# this should be included in any file that defines a pydantic model that uses a
# Union and calls to it should be removed when Pydantic's smart union support
# lands
def _check_hotfix() -> None:
if ModelField._validate_singleton != wrap_validate_singleton:
raise Exception("pydantic Union hotfix not applied")

View File

@ -10,6 +10,7 @@ from uuid import UUID, uuid4
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ._monkeypatch import _check_hotfix
from .enums import ( from .enums import (
OS, OS,
Architecture, Architecture,
@ -319,3 +320,6 @@ def parse_event_message(data: Dict[str, Any]) -> EventMessage:
instance_id=instance_id, instance_id=instance_id,
instance_name=instance_name, instance_name=instance_name,
) )
_check_hotfix()

View File

@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator from pydantic import BaseModel, Field, root_validator, validator
from ._monkeypatch import _check_hotfix
from .enums import OS, ContainerType, UserFieldOperation, UserFieldType from .enums import OS, ContainerType, UserFieldOperation, UserFieldType
from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers from .models import JobConfig, NotificationConfig, TaskConfig, TaskContainers
from .primitives import File from .primitives import File
@ -184,3 +185,6 @@ class JobTemplateGet(BaseRequest):
class JobTemplateRequestParameters(BaseRequest): class JobTemplateRequestParameters(BaseRequest):
user_fields: TemplateUserFields user_fields: TemplateUserFields
_check_hotfix()

View File

@ -10,6 +10,7 @@ from uuid import UUID, uuid4
from pydantic import BaseModel, Field, root_validator, validator from pydantic import BaseModel, Field, root_validator, validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import ( from .enums import (
OS, OS,
@ -842,3 +843,6 @@ class Task(BaseModel):
events: Optional[List[TaskEventSummary]] events: Optional[List[TaskEventSummary]]
nodes: Optional[List[NodeAssignment]] nodes: Optional[List[NodeAssignment]]
user_info: Optional[UserInfo] user_info: Optional[UserInfo]
_check_hotfix()

View File

@ -8,6 +8,7 @@ from uuid import UUID
from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator from pydantic import AnyHttpUrl, BaseModel, Field, root_validator, validator
from ._monkeypatch import _check_hotfix
from .consts import ONE_HOUR, SEVEN_DAYS from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import ( from .enums import (
OS, OS,
@ -250,3 +251,6 @@ class WebhookUpdate(BaseModel):
class NodeAddSshKey(BaseModel): class NodeAddSshKey(BaseModel):
machine_id: UUID machine_id: UUID
public_key: str public_key: str
_check_hotfix()