Storing secrets in azure keyvault (#326)

This commit is contained in:
Cheick Keita
2021-01-25 08:12:07 -08:00
committed by GitHub
parent dc31ffc92b
commit 3f2883d38e
12 changed files with 358 additions and 28 deletions

View File

@ -14,6 +14,7 @@ from typing import (
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
@ -33,11 +34,12 @@ from onefuzztypes.enums import (
UpdateType,
VmState,
)
from onefuzztypes.models import Error
from onefuzztypes.models import Error, SecretData
from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field
from typing_extensions import Protocol
from ..onefuzzlib.secrets import save_to_keyvault
from .azure.table import get_client
from .telemetry import track_event_filtered
from .updates import queue_update
@ -268,18 +270,49 @@ class ORMMixin(ModelMixin):
return (partition_key, row_key)
@classmethod
def hide_secrets(
cls,
model: BaseModel,
hider: Callable[["SecretData"], None],
visited: Set[int] = set(),
) -> None:
if id(model) in visited:
return
visited.add(id(model))
for field in model.__fields__:
field_data = getattr(model, field)
if isinstance(field_data, SecretData):
hider(field_data)
elif isinstance(field_data, List):
if len(field_data) > 0:
if not isinstance(field_data[0], BaseModel):
continue
for data in field_data:
cls.hide_secrets(data, hider, visited)
elif isinstance(field_data, dict):
for key in field_data:
if not isinstance(field_data[key], BaseModel):
continue
cls.hide_secrets(field_data[key], hider, visited)
else:
if isinstance(field_data, BaseModel):
cls.hide_secrets(field_data, hider, visited)
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
self.__class__.hide_secrets(self, save_to_keyvault)
# TODO: migrate to an inspect.signature() model
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
for key in raw:
if not isinstance(raw[key], (str, int)):
raw[key] = json.dumps(raw[key])
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
for field in self.__fields__:
if field not in raw:
continue
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
if self.__fields__[field].type_ == datetime:
raw[field] = getattr(self, field)