mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-21 13:51:19 +00:00
481 lines
14 KiB
Python
481 lines
14 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import inspect
|
|
import json
|
|
import logging
|
|
from datetime import datetime
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
from uuid import UUID
|
|
|
|
from azure.common import AzureConflictHttpError, AzureMissingResourceHttpError
|
|
from onefuzztypes.enums import (
|
|
ErrorCode,
|
|
JobState,
|
|
NodeState,
|
|
PoolState,
|
|
ScalesetState,
|
|
TaskState,
|
|
TelemetryEvent,
|
|
UpdateType,
|
|
VmState,
|
|
)
|
|
from onefuzztypes.models import Error, SecretData
|
|
from onefuzztypes.primitives import Container, PoolName, Region
|
|
from pydantic import BaseModel, Field
|
|
from typing_extensions import Protocol
|
|
|
|
from .azure.table import get_client
|
|
from .secrets import save_to_keyvault
|
|
from .telemetry import track_event_filtered
|
|
from .updates import queue_update
|
|
|
|
A = TypeVar("A", bound="ORMMixin")
|
|
|
|
QUERY_VALUE_TYPES = Union[
|
|
List[bool],
|
|
List[int],
|
|
List[str],
|
|
List[UUID],
|
|
List[Region],
|
|
List[Container],
|
|
List[PoolName],
|
|
List[VmState],
|
|
List[ScalesetState],
|
|
List[JobState],
|
|
List[TaskState],
|
|
List[PoolState],
|
|
List[NodeState],
|
|
]
|
|
QueryFilter = Dict[str, QUERY_VALUE_TYPES]
|
|
|
|
SAFE_STRINGS = (UUID, Container, Region, PoolName)
|
|
KEY = Union[int, str, UUID, Enum]
|
|
|
|
HOURS = 60 * 60
|
|
|
|
|
|
class HasState(Protocol):
|
|
# TODO: this should be bound tighter than Any
|
|
# In the end, we want this to be an Enum. Specifically, one of
|
|
# the JobState,TaskState,etc enums.
|
|
state: Any
|
|
|
|
|
|
def process_state_update(obj: HasState) -> None:
|
|
"""
|
|
process a single state update, if the obj
|
|
implements a function for that state
|
|
"""
|
|
|
|
func = getattr(obj, obj.state.name, None)
|
|
if func is None:
|
|
return
|
|
|
|
get_keys = getattr(obj, "get_keys", None)
|
|
if get_keys is not None:
|
|
logging.info("processing state update: %s - %s", get_keys(), obj.state.name)
|
|
|
|
func()
|
|
|
|
|
|
def process_state_updates(obj: HasState, max_updates: int = 5) -> None:
|
|
""" process through the state machine for an object """
|
|
|
|
for _ in range(max_updates):
|
|
state = obj.state
|
|
process_state_update(obj)
|
|
new_state = obj.state
|
|
if new_state == state:
|
|
break
|
|
|
|
|
|
def resolve(key: KEY) -> str:
|
|
if isinstance(key, str):
|
|
return key
|
|
elif isinstance(key, UUID):
|
|
return str(key)
|
|
elif isinstance(key, Enum):
|
|
return key.name
|
|
elif isinstance(key, int):
|
|
return str(key)
|
|
raise NotImplementedError("unsupported type %s - %s" % (type(key), repr(key)))
|
|
|
|
|
|
def build_filters(
|
|
cls: Type[A], query_args: Optional[QueryFilter]
|
|
) -> Tuple[Optional[str], QueryFilter]:
|
|
if not query_args:
|
|
return (None, {})
|
|
|
|
partition_key_field, row_key_field = cls.key_fields()
|
|
|
|
search_filter_parts = []
|
|
post_filters: QueryFilter = {}
|
|
|
|
for field, values in query_args.items():
|
|
if field not in cls.__fields__:
|
|
raise ValueError("unexpected field %s: %s" % (repr(field), cls))
|
|
|
|
if not values:
|
|
continue
|
|
|
|
if field == partition_key_field:
|
|
field_name = "PartitionKey"
|
|
elif field == row_key_field:
|
|
field_name = "RowKey"
|
|
else:
|
|
field_name = field
|
|
|
|
parts: Optional[List[str]] = None
|
|
if isinstance(values[0], bool):
|
|
parts = []
|
|
for x in values:
|
|
if not isinstance(x, bool):
|
|
raise TypeError("unexpected type")
|
|
parts.append("%s eq %s" % (field_name, str(x).lower()))
|
|
elif isinstance(values[0], int):
|
|
parts = []
|
|
for x in values:
|
|
if not isinstance(x, int):
|
|
raise TypeError("unexpected type")
|
|
parts.append("%s eq %d" % (field_name, x))
|
|
elif isinstance(values[0], Enum):
|
|
parts = []
|
|
for x in values:
|
|
if not isinstance(x, Enum):
|
|
raise TypeError("unexpected type")
|
|
parts.append("%s eq '%s'" % (field_name, x.name))
|
|
elif all(isinstance(x, SAFE_STRINGS) for x in values):
|
|
parts = ["%s eq '%s'" % (field_name, x) for x in values]
|
|
else:
|
|
post_filters[field_name] = values
|
|
|
|
if parts:
|
|
if len(parts) == 1:
|
|
search_filter_parts.append(parts[0])
|
|
else:
|
|
search_filter_parts.append("(" + " or ".join(parts) + ")")
|
|
|
|
if search_filter_parts:
|
|
return (" and ".join(search_filter_parts), post_filters)
|
|
|
|
return (None, post_filters)
|
|
|
|
|
|
def post_filter(value: Any, filters: Optional[QueryFilter]) -> bool:
|
|
if not filters:
|
|
return True
|
|
|
|
for field in filters:
|
|
if field not in value:
|
|
return False
|
|
if value[field] not in filters[field]:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
MappingIntStrAny = Mapping[Union[int, str], Any]
|
|
# A = TypeVar("A", bound="Model")
|
|
|
|
|
|
class ModelMixin(BaseModel):
|
|
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
|
return None
|
|
|
|
def raw(
|
|
self,
|
|
*,
|
|
by_alias: bool = False,
|
|
exclude_none: bool = False,
|
|
exclude: MappingIntStrAny = None,
|
|
include: MappingIntStrAny = None,
|
|
) -> Dict[str, Any]:
|
|
# cycling through json means all wrapped types get resolved, such as UUID
|
|
result: Dict[str, Any] = json.loads(
|
|
self.json(
|
|
by_alias=by_alias,
|
|
exclude_none=exclude_none,
|
|
exclude=exclude,
|
|
include=include,
|
|
)
|
|
)
|
|
return result
|
|
|
|
|
|
class ORMMixin(ModelMixin):
|
|
Timestamp: Optional[datetime] = Field(alias="Timestamp")
|
|
etag: Optional[str]
|
|
|
|
@classmethod
|
|
def table_name(cls: Type[A]) -> str:
|
|
return cls.__name__
|
|
|
|
@classmethod
|
|
def get(
|
|
cls: Type[A], PartitionKey: KEY, RowKey: Optional[KEY] = None
|
|
) -> Optional[A]:
|
|
client = get_client(table=cls.table_name())
|
|
partition_key = resolve(PartitionKey)
|
|
row_key = resolve(RowKey) if RowKey else partition_key
|
|
|
|
try:
|
|
raw = client.get_entity(cls.table_name(), partition_key, row_key)
|
|
except AzureMissingResourceHttpError:
|
|
return None
|
|
return cls.load(raw)
|
|
|
|
@classmethod
|
|
def key_fields(cls) -> Tuple[str, Optional[str]]:
|
|
raise NotImplementedError("keys not defined")
|
|
|
|
# FILTERS:
|
|
# The following
|
|
# * save_exclude: Specify fields to *exclude* from saving to Storage Tables
|
|
# * export_exclude: Specify the fields to *exclude* from sending to an external API
|
|
# * telemetry_include: Specify the fields to *include* for telemetry
|
|
#
|
|
# For implementation details see:
|
|
# https://pydantic-docs.helpmanual.io/usage/exporting_models/#advanced-include-and-exclude
|
|
def save_exclude(self) -> Optional[MappingIntStrAny]:
|
|
return None
|
|
|
|
def export_exclude(self) -> Optional[MappingIntStrAny]:
|
|
return {"etag": ..., "Timestamp": ...}
|
|
|
|
def telemetry_include(self) -> Optional[MappingIntStrAny]:
|
|
return {}
|
|
|
|
def telemetry(self) -> Any:
|
|
return self.raw(exclude_none=True, include=self.telemetry_include())
|
|
|
|
def get_keys(self) -> Tuple[KEY, KEY]:
|
|
partition_key_field, row_key_field = self.key_fields()
|
|
|
|
partition_key = getattr(self, partition_key_field)
|
|
if row_key_field:
|
|
row_key = getattr(self, row_key_field)
|
|
else:
|
|
row_key = partition_key
|
|
|
|
return (partition_key, row_key)
|
|
|
|
@classmethod
|
|
def hide_secrets(
|
|
cls,
|
|
model: BaseModel,
|
|
hider: Callable[["SecretData"], None],
|
|
visited: Set[int] = set(),
|
|
) -> None:
|
|
if id(model) in visited:
|
|
return
|
|
|
|
visited.add(id(model))
|
|
for field in model.__fields__:
|
|
field_data = getattr(model, field)
|
|
if isinstance(field_data, SecretData):
|
|
hider(field_data)
|
|
elif isinstance(field_data, List):
|
|
if len(field_data) > 0:
|
|
if not isinstance(field_data[0], BaseModel):
|
|
continue
|
|
for data in field_data:
|
|
cls.hide_secrets(data, hider, visited)
|
|
elif isinstance(field_data, dict):
|
|
for key in field_data:
|
|
if not isinstance(field_data[key], BaseModel):
|
|
continue
|
|
cls.hide_secrets(field_data[key], hider, visited)
|
|
else:
|
|
if isinstance(field_data, BaseModel):
|
|
cls.hide_secrets(field_data, hider, visited)
|
|
|
|
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
|
self.__class__.hide_secrets(self, save_to_keyvault)
|
|
# TODO: migrate to an inspect.signature() model
|
|
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
|
for key in raw:
|
|
if not isinstance(raw[key], (str, int)):
|
|
raw[key] = json.dumps(raw[key])
|
|
|
|
for field in self.__fields__:
|
|
if field not in raw:
|
|
continue
|
|
# for datetime fields that passed through filtering, use the real value,
|
|
# rather than a serialized form
|
|
if self.__fields__[field].type_ == datetime:
|
|
raw[field] = getattr(self, field)
|
|
|
|
partition_key_field, row_key_field = self.key_fields()
|
|
|
|
# PartitionKey and RowKey must be 'str'
|
|
raw["PartitionKey"] = resolve(raw[partition_key_field])
|
|
raw["RowKey"] = resolve(raw[row_key_field or partition_key_field])
|
|
|
|
del raw[partition_key_field]
|
|
if row_key_field in raw:
|
|
del raw[row_key_field]
|
|
|
|
client = get_client(table=self.table_name())
|
|
|
|
# never save the timestamp
|
|
if "Timestamp" in raw:
|
|
del raw["Timestamp"]
|
|
|
|
if new:
|
|
try:
|
|
self.etag = client.insert_entity(self.table_name(), raw)
|
|
except AzureConflictHttpError:
|
|
return Error(
|
|
code=ErrorCode.UNABLE_TO_CREATE, errors=["entry already exists"]
|
|
)
|
|
elif self.etag and require_etag:
|
|
self.etag = client.replace_entity(
|
|
self.table_name(), raw, if_match=self.etag
|
|
)
|
|
else:
|
|
self.etag = client.insert_or_replace_entity(self.table_name(), raw)
|
|
|
|
if self.table_name() in TelemetryEvent.__members__:
|
|
telem = self.telemetry()
|
|
if telem:
|
|
track_event_filtered(TelemetryEvent[self.table_name()], telem)
|
|
|
|
return None
|
|
|
|
def delete(self) -> None:
|
|
partition_key, row_key = self.get_keys()
|
|
|
|
client = get_client()
|
|
try:
|
|
client.delete_entity(
|
|
self.table_name(), resolve(partition_key), resolve(row_key)
|
|
)
|
|
except AzureMissingResourceHttpError:
|
|
# It's OK if the component is already deleted
|
|
pass
|
|
|
|
@classmethod
|
|
def load(cls: Type[A], data: Dict[str, Union[str, bytes, bytearray]]) -> A:
|
|
partition_key_field, row_key_field = cls.key_fields()
|
|
|
|
if partition_key_field in data:
|
|
raise Exception(
|
|
"duplicate PartitionKey field %s for %s"
|
|
% (partition_key_field, cls.table_name())
|
|
)
|
|
if row_key_field in data:
|
|
raise Exception(
|
|
"duplicate RowKey field %s for %s" % (row_key_field, cls.table_name())
|
|
)
|
|
|
|
data[partition_key_field] = data["PartitionKey"]
|
|
if row_key_field is not None:
|
|
data[row_key_field] = data["RowKey"]
|
|
|
|
del data["PartitionKey"]
|
|
del data["RowKey"]
|
|
|
|
for key in inspect.signature(cls).parameters:
|
|
if key not in data:
|
|
continue
|
|
|
|
annotation = inspect.signature(cls).parameters[key].annotation
|
|
|
|
if inspect.isclass(annotation):
|
|
if (
|
|
issubclass(annotation, BaseModel)
|
|
or issubclass(annotation, dict)
|
|
or issubclass(annotation, list)
|
|
):
|
|
data[key] = json.loads(data[key])
|
|
continue
|
|
|
|
if getattr(annotation, "__origin__", None) == Union and any(
|
|
inspect.isclass(x) and issubclass(x, BaseModel)
|
|
for x in annotation.__args__
|
|
):
|
|
data[key] = json.loads(data[key])
|
|
continue
|
|
|
|
# Required for Python >=3.7. In 3.6, a `Dict[_,_]` and `List[_]` annotations
|
|
# are a class according to `inspect.isclass`.
|
|
if getattr(annotation, "__origin__", None) in [dict, list]:
|
|
data[key] = json.loads(data[key])
|
|
continue
|
|
|
|
return cls.parse_obj(data)
|
|
|
|
@classmethod
|
|
def search(
|
|
cls: Type[A],
|
|
*,
|
|
query: Optional[QueryFilter] = None,
|
|
raw_unchecked_filter: Optional[str] = None,
|
|
num_results: int = None,
|
|
) -> List[A]:
|
|
search_filter, post_filters = build_filters(cls, query)
|
|
|
|
if raw_unchecked_filter is not None:
|
|
if search_filter is None:
|
|
search_filter = raw_unchecked_filter
|
|
else:
|
|
search_filter = "(%s) and (%s)" % (search_filter, raw_unchecked_filter)
|
|
|
|
client = get_client(table=cls.table_name())
|
|
entries = []
|
|
for row in client.query_entities(
|
|
cls.table_name(), filter=search_filter, num_results=num_results
|
|
):
|
|
if not post_filter(row, post_filters):
|
|
continue
|
|
|
|
entry = cls.load(row)
|
|
entries.append(entry)
|
|
return entries
|
|
|
|
def queue(
|
|
self,
|
|
*,
|
|
method: Optional[Callable] = None,
|
|
visibility_timeout: Optional[int] = None,
|
|
) -> None:
|
|
if not hasattr(self, "state"):
|
|
raise NotImplementedError("Queued an ORM mapping without State")
|
|
|
|
update_type = UpdateType.__members__.get(type(self).__name__)
|
|
if update_type is None:
|
|
raise NotImplementedError("unsupported update type: %s" % self)
|
|
|
|
method_name: Optional[str] = None
|
|
if method is not None:
|
|
if not hasattr(method, "__name__"):
|
|
raise Exception("unable to queue method: %s" % method)
|
|
method_name = method.__name__
|
|
|
|
partition_key, row_key = self.get_keys()
|
|
|
|
queue_update(
|
|
update_type,
|
|
resolve(partition_key),
|
|
resolve(row_key),
|
|
method=method_name,
|
|
visibility_timeout=visibility_timeout,
|
|
)
|