Storing secrets in azure keyvault (#326)

This commit is contained in:
Cheick Keita
2021-01-25 08:12:07 -08:00
committed by GitHub
parent dc31ffc92b
commit 3f2883d38e
12 changed files with 358 additions and 28 deletions

View File

@ -17,9 +17,6 @@ from ..onefuzzlib.request import not_ok, ok, parse_request
def get(req: func.HttpRequest) -> func.HttpResponse: def get(req: func.HttpRequest) -> func.HttpResponse:
entries = Notification.search() entries = Notification.search()
for entry in entries:
entry.config.redact()
return ok(entries) return ok(entries)
@ -46,7 +43,6 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(entry, context="notification delete") return not_ok(entry, context="notification delete")
entry.delete() entry.delete()
entry.config.redact()
return ok(entry) return ok(entry)

View File

@ -11,6 +11,8 @@ from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.mgmt.resource import ResourceManagementClient from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.subscription import SubscriptionClient from azure.mgmt.subscription import SubscriptionClient
from memoization import cached from memoization import cached
@ -134,3 +136,8 @@ def get_scaleset_principal_id() -> UUID:
client = mgmt_client_factory(ResourceManagementClient) client = mgmt_client_factory(ResourceManagementClient)
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version) uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
return UUID(uid.properties["principalId"]) return UUID(uid.properties["principalId"])
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())

View File

@ -27,6 +27,7 @@ from memoization import cached
from onefuzztypes.models import ADOTemplate, Report from onefuzztypes.models import ADOTemplate, Report
from onefuzztypes.primitives import Container from onefuzztypes.primitives import Container
from ..secrets import get_secret_string_value
from .common import Render, fail_task from .common import Render, fail_task
@ -54,7 +55,8 @@ class ADO:
): ):
self.config = config self.config = config
self.renderer = Render(container, filename, report) self.renderer = Render(container, filename, report)
self.client = get_ado_client(self.config.base_url, self.config.auth_token) auth_token = get_secret_string_value(self.config.auth_token)
self.client = get_ado_client(self.config.base_url, auth_token)
self.project = self.render(self.config.project) self.project = self.render(self.config.project)
def render(self, template: str) -> str: def render(self, template: str) -> str:

View File

@ -10,9 +10,10 @@ from github3 import login
from github3.exceptions import GitHubException from github3.exceptions import GitHubException
from github3.issues import Issue from github3.issues import Issue
from onefuzztypes.enums import GithubIssueSearchMatch from onefuzztypes.enums import GithubIssueSearchMatch
from onefuzztypes.models import GithubIssueTemplate, Report from onefuzztypes.models import GithubAuth, GithubIssueTemplate, Report
from onefuzztypes.primitives import Container from onefuzztypes.primitives import Container
from ..secrets import get_secret_obj
from .common import Render, fail_task from .common import Render, fail_task
@ -26,9 +27,12 @@ class GithubIssue:
): ):
self.config = config self.config = config
self.report = report self.report = report
self.gh = login( if isinstance(config.auth.secret, GithubAuth):
username=config.auth.user, password=config.auth.personal_access_token auth = config.auth.secret
) else:
auth = get_secret_obj(config.auth.secret.url, GithubAuth)
self.gh = login(username=auth.user, password=auth.personal_access_token)
self.renderer = Render(container, filename, report) self.renderer = Render(container, filename, report)
def render(self, field: str) -> str: def render(self, field: str) -> str:

View File

@ -11,6 +11,7 @@ from onefuzztypes.models import Report, TeamsTemplate
from onefuzztypes.primitives import Container from onefuzztypes.primitives import Container
from ..azure.containers import auth_download_url from ..azure.containers import auth_download_url
from ..secrets import get_secret_string_value
from ..tasks.config import get_setup_container from ..tasks.config import get_setup_container
from ..tasks.main import Task from ..tasks.main import Task
@ -46,7 +47,8 @@ def send_teams_webhook(
if text: if text:
message["sections"].append({"text": text}) message["sections"].append({"text": text})
response = requests.post(config.url, json=message) config_url = get_secret_string_value(config.url)
response = requests.post(config_url, json=message)
if not response.ok: if not response.ok:
logging.error("webhook failed %s %s", response.status_code, response.content) logging.error("webhook failed %s %s", response.status_code, response.content)

View File

@ -14,6 +14,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Set,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -33,11 +34,12 @@ from onefuzztypes.enums import (
UpdateType, UpdateType,
VmState, VmState,
) )
from onefuzztypes.models import Error from onefuzztypes.models import Error, SecretData
from onefuzztypes.primitives import Container, PoolName, Region from onefuzztypes.primitives import Container, PoolName, Region
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Protocol from typing_extensions import Protocol
from ..onefuzzlib.secrets import save_to_keyvault
from .azure.table import get_client from .azure.table import get_client
from .telemetry import track_event_filtered from .telemetry import track_event_filtered
from .updates import queue_update from .updates import queue_update
@ -268,18 +270,49 @@ class ORMMixin(ModelMixin):
return (partition_key, row_key) return (partition_key, row_key)
@classmethod
def hide_secrets(
cls,
model: BaseModel,
hider: Callable[["SecretData"], None],
visited: Set[int] = set(),
) -> None:
if id(model) in visited:
return
visited.add(id(model))
for field in model.__fields__:
field_data = getattr(model, field)
if isinstance(field_data, SecretData):
hider(field_data)
elif isinstance(field_data, List):
if len(field_data) > 0:
if not isinstance(field_data[0], BaseModel):
continue
for data in field_data:
cls.hide_secrets(data, hider, visited)
elif isinstance(field_data, dict):
for key in field_data:
if not isinstance(field_data[key], BaseModel):
continue
cls.hide_secrets(field_data[key], hider, visited)
else:
if isinstance(field_data, BaseModel):
cls.hide_secrets(field_data, hider, visited)
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]: def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
self.__class__.hide_secrets(self, save_to_keyvault)
# TODO: migrate to an inspect.signature() model # TODO: migrate to an inspect.signature() model
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude()) raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
for key in raw: for key in raw:
if not isinstance(raw[key], (str, int)): if not isinstance(raw[key], (str, int)):
raw[key] = json.dumps(raw[key]) raw[key] = json.dumps(raw[key])
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
for field in self.__fields__: for field in self.__fields__:
if field not in raw: if field not in raw:
continue continue
# for datetime fields that passed through filtering, use the real value,
# rather than a serialized form
if self.__fields__[field].type_ == datetime: if self.__fields__[field].type_ == datetime:
raw[field] = getattr(self, field) raw[field] = getattr(self, field)

View File

@ -0,0 +1,79 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Tuple, Type, TypeVar, cast
from urllib.parse import urlparse
from uuid import uuid4
from azure.keyvault.secrets import KeyVaultSecret
from onefuzztypes.models import SecretAddress, SecretData
from pydantic import BaseModel
from .azure.creds import get_instance_name, get_keyvault_client
A = TypeVar("A", bound=BaseModel)
def save_to_keyvault(secret_data: SecretData) -> None:
if isinstance(secret_data.secret, SecretAddress):
return
secret_name = str(uuid4())
if isinstance(secret_data.secret, str):
secret_value = secret_data.secret
elif isinstance(secret_data.secret, BaseModel):
secret_value = secret_data.secret.json()
else:
raise Exception("invalid secret data")
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
secret_data.secret = SecretAddress(url=kv.id)
def get_secret_string_value(self: SecretData[str]) -> str:
if isinstance(self.secret, SecretAddress):
secret = get_secret(self.secret.url)
return cast(str, secret.value)
else:
return self.secret
def get_keyvault_address() -> str:
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
return f"https://{get_instance_name()}-vault.vault.azure.net"
def store_in_keyvault(
keyvault_url: str, secret_name: str, secret_value: str
) -> KeyVaultSecret:
keyvault_client = get_keyvault_client(keyvault_url)
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
return kvs
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
u = urlparse(secret_url)
vault_url = f"{u.scheme}://{u.netloc}"
secret_name = u.path.split("/")[2]
return (vault_url, secret_name)
def get_secret(secret_url: str) -> KeyVaultSecret:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
return keyvault_client.get_secret(secret_name)
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
secret = get_secret(secret_url)
return model.parse_raw(secret.value)
def delete_secret(secret_url: str) -> None:
(vault_url, secret_name) = parse_secret_url(secret_url)
keyvault_client = get_keyvault_client(vault_url)
keyvault_client.begin_delete_secret(secret_name).wait()

View File

@ -0,0 +1,91 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import json
import unittest
from onefuzztypes.enums import OS, ContainerType
from onefuzztypes.job_templates import (
JobTemplate,
JobTemplateIndex,
JobTemplateNotification,
)
from onefuzztypes.models import (
JobConfig,
Notification,
NotificationConfig,
SecretAddress,
SecretData,
TeamsTemplate,
)
from onefuzztypes.primitives import Container
from __app__.onefuzzlib.orm import ORMMixin
class TestSecret(unittest.TestCase):
def test_hide(self) -> None:
def hider(secret_data: SecretData) -> None:
if not isinstance(secret_data.secret, SecretAddress):
secret_data.secret = SecretAddress(url="blah blah")
notification = Notification(
container=Container("data"),
config=TeamsTemplate(url=SecretData(secret="http://test")),
)
ORMMixin.hide_secrets(notification, hider)
if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")
def test_hide_nested_list(self) -> None:
def hider(secret_data: SecretData) -> None:
if not isinstance(secret_data.secret, SecretAddress):
secret_data.secret = SecretAddress(url="blah blah")
job_template_index = JobTemplateIndex(
name="test",
template=JobTemplate(
os=OS.linux,
job=JobConfig(name="test", build="test", project="test", duration=1),
tasks=[],
notifications=[
JobTemplateNotification(
container_type=ContainerType.unique_inputs,
notification=NotificationConfig(
config=TeamsTemplate(url=SecretData(secret="http://test"))
),
)
],
user_fields=[],
),
)
ORMMixin.hide_secrets(job_template_index, hider)
notification = job_template_index.template.notifications[0].notification
if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")
def test_read_secret(self) -> None:
json_data = """
{
"notification_id": "b52b24d1-eec6-46c9-b06a-818a997da43c",
"container": "data",
"config" : {"url": {"secret": {"url": "http://test"}}}
}
"""
data = json.loads(json_data)
notification = Notification.parse_obj(data)
self.assertIsInstance(notification.config, TeamsTemplate)
if isinstance(notification.config, TeamsTemplate):
self.assertIsInstance(notification.config.url, SecretData)
self.assertIsInstance(notification.config.url.secret, SecretAddress)
else:
self.fail(f"Invalid config type {type(notification.config)}")

View File

@ -11,6 +11,7 @@ import logging
import os import os
import sys import sys
import time import time
from dataclasses import asdict, is_dataclass
from enum import Enum from enum import Enum
from typing import ( from typing import (
Any, Any,
@ -381,6 +382,8 @@ def serialize(data: Any) -> Any:
return str(data) return str(data)
if isinstance(data, (int, str)): if isinstance(data, (int, str)):
return data return data
if is_dataclass(data):
return {serialize(a): serialize(b) for (a, b) in asdict(data).items()}
raise Exception("unknown type %s" % type(data)) raise Exception("unknown type %s" % type(data))

View File

@ -56,7 +56,8 @@
"Managed Identity Operator": "f1a07417-d97a-45cb-824c-7a7467783830", "Managed Identity Operator": "f1a07417-d97a-45cb-824c-7a7467783830",
"Network Contributor": "4d97b98b-1d4f-4787-a291-c67834d212e7", "Network Contributor": "4d97b98b-1d4f-4787-a291-c67834d212e7",
"Storage Account Contributor": "17d1049b-9a84-46fb-8f53-869881c3d3ab", "Storage Account Contributor": "17d1049b-9a84-46fb-8f53-869881c3d3ab",
"Virtual Machine Contributor": "9980e02c-c2be-4d73-94e8-173b1dc7cf3c" "Virtual Machine Contributor": "9980e02c-c2be-4d73-94e8-173b1dc7cf3c",
"keyVaultName": "[concat(parameters('name'), '-vault')]"
}, },
"functions": [ "functions": [
{ {
@ -101,6 +102,34 @@
"apiVersion": "2018-11-30", "apiVersion": "2018-11-30",
"location": "[resourceGroup().location]" "location": "[resourceGroup().location]"
}, },
{
"type": "Microsoft.KeyVault/vaults",
"apiVersion": "2019-09-01",
"name": "[variables('keyVaultName')]",
"location": "[resourceGroup().location]",
"properties": {
"enabledForDiskEncryption": false,
"enabledForTemplateDeployment": true,
"tenantId": "[subscription().tenantId]",
"accessPolicies": [
{
"objectId": "[reference(resourceId('Microsoft.Web/sites', parameters('name')), '2019-08-01', 'full').identity.principalId]",
"tenantId": "[subscription().tenantId]",
"permissions": {
"secrets": ["get", "list", "set", "delete"]
}
}
],
"sku": {
"name": "standard",
"family": "A"
},
"networkAcls": {
"defaultAction": "Allow",
"bypass": "AzureServices"
}
}
},
{ {
"apiVersion": "2018-11-01", "apiVersion": "2018-11-01",
"name": "[parameters('name')]", "name": "[parameters('name')]",

View File

@ -4,10 +4,11 @@
# Licensed under the MIT License. # Licensed under the MIT License.
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from pydantic import BaseModel, Field, root_validator, validator from pydantic import BaseModel, Field, root_validator, validator
from pydantic.dataclasses import dataclass
from .consts import ONE_HOUR, SEVEN_DAYS from .consts import ONE_HOUR, SEVEN_DAYS
from .enums import ( from .enums import (
@ -41,6 +42,39 @@ class UserInfo(BaseModel):
upn: Optional[str] upn: Optional[str]
# Stores the address of a secret
class SecretAddress(BaseModel):
# keyvault address of a secret
url: str
T = TypeVar("T")
# This class allows us to store some data that are intended to be secret
# The secret field stores either the raw data or the address of that data
# This class allows us to maintain backward compatibility with existing
# NotificationTemplate classes
@dataclass
class SecretData(Generic[T]):
secret: Union[T, SecretAddress]
def __init__(self, secret: Union[T, SecretAddress]):
if isinstance(secret, dict):
self.secret = SecretAddress.parse_obj(secret)
else:
self.secret = secret
def __str__(self) -> str:
return self.__repr__()
def __repr__(self) -> str:
if isinstance(self.secret, SecretAddress):
return str(self.secret)
else:
return "[REDACTED]"
class EnumModel(BaseModel): class EnumModel(BaseModel):
@root_validator(pre=True) @root_validator(pre=True)
def exactly_one(cls: Any, values: Any) -> Any: def exactly_one(cls: Any, values: Any) -> Any:
@ -226,7 +260,7 @@ class ADODuplicateTemplate(BaseModel):
class ADOTemplate(BaseModel): class ADOTemplate(BaseModel):
base_url: str base_url: str
auth_token: str auth_token: SecretData[str]
project: str project: str
type: str type: str
unique_fields: List[str] unique_fields: List[str]
@ -234,15 +268,33 @@ class ADOTemplate(BaseModel):
ado_fields: Dict[str, str] ado_fields: Dict[str, str]
on_duplicate: ADODuplicateTemplate on_duplicate: ADODuplicateTemplate
def redact(self) -> None: # validator needed for backward compatibility
self.auth_token = "***" @validator("auth_token", pre=True, always=True)
def validate_auth_token(cls, v: Any) -> SecretData:
if isinstance(v, str):
return SecretData(secret=v)
elif isinstance(v, SecretData):
return v
elif isinstance(v, dict):
return SecretData(secret=v["secret"])
else:
raise TypeError(f"invalid datatype {type(v)}")
class TeamsTemplate(BaseModel): class TeamsTemplate(BaseModel):
url: str url: SecretData[str]
def redact(self) -> None: # validator needed for backward compatibility
self.url = "***" @validator("url", pre=True, always=True)
def validate_url(cls, v: Any) -> SecretData:
if isinstance(v, str):
return SecretData(secret=v)
elif isinstance(v, SecretData):
return v
elif isinstance(v, dict):
return SecretData(secret=v["secret"])
else:
raise TypeError(f"invalid datatype {type(v)}")
class ContainerDefinition(BaseModel): class ContainerDefinition(BaseModel):
@ -408,7 +460,7 @@ class GithubAuth(BaseModel):
class GithubIssueTemplate(BaseModel): class GithubIssueTemplate(BaseModel):
auth: GithubAuth auth: SecretData[GithubAuth]
organization: str organization: str
repository: str repository: str
title: str title: str
@ -418,9 +470,20 @@ class GithubIssueTemplate(BaseModel):
labels: List[str] labels: List[str]
on_duplicate: GithubIssueDuplicate on_duplicate: GithubIssueDuplicate
def redact(self) -> None: # validator needed for backward compatibility
self.auth.user = "***" @validator("auth", pre=True, always=True)
self.auth.personal_access_token = "***" def validate_auth(cls, v: Any) -> SecretData:
if isinstance(v, str):
return SecretData(secret=v)
elif isinstance(v, SecretData):
return v
elif isinstance(v, dict):
try:
return SecretData(GithubAuth.parse_obj(v))
except Exception:
return SecretData(secret=v["secret"])
else:
raise TypeError(f"invalid datatype {type(v)}")
NotificationTemplate = Union[ADOTemplate, TeamsTemplate, GithubIssueTemplate] NotificationTemplate = Union[ADOTemplate, TeamsTemplate, GithubIssueTemplate]

View File

@ -5,14 +5,34 @@
import unittest import unittest
from pydantic import ValidationError from onefuzztypes.models import Scaleset, SecretData, TeamsTemplate
from onefuzztypes.models import Scaleset, TeamsTemplate
from onefuzztypes.requests import NotificationCreate from onefuzztypes.requests import NotificationCreate
from pydantic import ValidationError
class TestModelsVerify(unittest.TestCase): class TestModelsVerify(unittest.TestCase):
def test_model(self) -> None: def test_model(self) -> None:
data = {
"container": "data",
"config": {"url": {"secret": "https://www.contoso.com/"}},
}
notification = NotificationCreate.parse_obj(data)
self.assertIsInstance(notification.config, TeamsTemplate)
self.assertIsInstance(notification.config.url, SecretData)
self.assertEqual(
notification.config.url.secret,
"https://www.contoso.com/",
"mismatch secret value",
)
missing_container = {
"config": {"url": "https://www.contoso.com/"},
}
with self.assertRaises(ValidationError):
NotificationCreate.parse_obj(missing_container)
def test_legacy_model(self) -> None:
data = { data = {
"container": "data", "container": "data",
"config": {"url": "https://www.contoso.com/"}, "config": {"url": "https://www.contoso.com/"},
@ -20,6 +40,7 @@ class TestModelsVerify(unittest.TestCase):
notification = NotificationCreate.parse_obj(data) notification = NotificationCreate.parse_obj(data)
self.assertIsInstance(notification.config, TeamsTemplate) self.assertIsInstance(notification.config, TeamsTemplate)
self.assertIsInstance(notification.config.url, SecretData)
missing_container = { missing_container = {
"config": {"url": "https://www.contoso.com/"}, "config": {"url": "https://www.contoso.com/"},