mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-15 03:18:07 +00:00
Storing secrets in azure keyvault (#326)
This commit is contained in:
@ -17,9 +17,6 @@ from ..onefuzzlib.request import not_ok, ok, parse_request
|
|||||||
|
|
||||||
def get(req: func.HttpRequest) -> func.HttpResponse:
|
def get(req: func.HttpRequest) -> func.HttpResponse:
|
||||||
entries = Notification.search()
|
entries = Notification.search()
|
||||||
for entry in entries:
|
|
||||||
entry.config.redact()
|
|
||||||
|
|
||||||
return ok(entries)
|
return ok(entries)
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +43,6 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
|
|||||||
return not_ok(entry, context="notification delete")
|
return not_ok(entry, context="notification delete")
|
||||||
|
|
||||||
entry.delete()
|
entry.delete()
|
||||||
entry.config.redact()
|
|
||||||
return ok(entry)
|
return ok(entry)
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,6 +11,8 @@ from azure.cli.core import CLIError
|
|||||||
from azure.common.client_factory import get_client_from_cli_profile
|
from azure.common.client_factory import get_client_from_cli_profile
|
||||||
from azure.graphrbac import GraphRbacManagementClient
|
from azure.graphrbac import GraphRbacManagementClient
|
||||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||||
|
from azure.identity import DefaultAzureCredential
|
||||||
|
from azure.keyvault.secrets import SecretClient
|
||||||
from azure.mgmt.resource import ResourceManagementClient
|
from azure.mgmt.resource import ResourceManagementClient
|
||||||
from azure.mgmt.subscription import SubscriptionClient
|
from azure.mgmt.subscription import SubscriptionClient
|
||||||
from memoization import cached
|
from memoization import cached
|
||||||
@ -134,3 +136,8 @@ def get_scaleset_principal_id() -> UUID:
|
|||||||
client = mgmt_client_factory(ResourceManagementClient)
|
client = mgmt_client_factory(ResourceManagementClient)
|
||||||
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
|
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
|
||||||
return UUID(uid.properties["principalId"])
|
return UUID(uid.properties["principalId"])
|
||||||
|
|
||||||
|
|
||||||
|
@cached
|
||||||
|
def get_keyvault_client(vault_url: str) -> SecretClient:
|
||||||
|
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
||||||
|
@ -27,6 +27,7 @@ from memoization import cached
|
|||||||
from onefuzztypes.models import ADOTemplate, Report
|
from onefuzztypes.models import ADOTemplate, Report
|
||||||
from onefuzztypes.primitives import Container
|
from onefuzztypes.primitives import Container
|
||||||
|
|
||||||
|
from ..secrets import get_secret_string_value
|
||||||
from .common import Render, fail_task
|
from .common import Render, fail_task
|
||||||
|
|
||||||
|
|
||||||
@ -54,7 +55,8 @@ class ADO:
|
|||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.renderer = Render(container, filename, report)
|
self.renderer = Render(container, filename, report)
|
||||||
self.client = get_ado_client(self.config.base_url, self.config.auth_token)
|
auth_token = get_secret_string_value(self.config.auth_token)
|
||||||
|
self.client = get_ado_client(self.config.base_url, auth_token)
|
||||||
self.project = self.render(self.config.project)
|
self.project = self.render(self.config.project)
|
||||||
|
|
||||||
def render(self, template: str) -> str:
|
def render(self, template: str) -> str:
|
||||||
|
@ -10,9 +10,10 @@ from github3 import login
|
|||||||
from github3.exceptions import GitHubException
|
from github3.exceptions import GitHubException
|
||||||
from github3.issues import Issue
|
from github3.issues import Issue
|
||||||
from onefuzztypes.enums import GithubIssueSearchMatch
|
from onefuzztypes.enums import GithubIssueSearchMatch
|
||||||
from onefuzztypes.models import GithubIssueTemplate, Report
|
from onefuzztypes.models import GithubAuth, GithubIssueTemplate, Report
|
||||||
from onefuzztypes.primitives import Container
|
from onefuzztypes.primitives import Container
|
||||||
|
|
||||||
|
from ..secrets import get_secret_obj
|
||||||
from .common import Render, fail_task
|
from .common import Render, fail_task
|
||||||
|
|
||||||
|
|
||||||
@ -26,9 +27,12 @@ class GithubIssue:
|
|||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.report = report
|
self.report = report
|
||||||
self.gh = login(
|
if isinstance(config.auth.secret, GithubAuth):
|
||||||
username=config.auth.user, password=config.auth.personal_access_token
|
auth = config.auth.secret
|
||||||
)
|
else:
|
||||||
|
auth = get_secret_obj(config.auth.secret.url, GithubAuth)
|
||||||
|
|
||||||
|
self.gh = login(username=auth.user, password=auth.personal_access_token)
|
||||||
self.renderer = Render(container, filename, report)
|
self.renderer = Render(container, filename, report)
|
||||||
|
|
||||||
def render(self, field: str) -> str:
|
def render(self, field: str) -> str:
|
||||||
|
@ -11,6 +11,7 @@ from onefuzztypes.models import Report, TeamsTemplate
|
|||||||
from onefuzztypes.primitives import Container
|
from onefuzztypes.primitives import Container
|
||||||
|
|
||||||
from ..azure.containers import auth_download_url
|
from ..azure.containers import auth_download_url
|
||||||
|
from ..secrets import get_secret_string_value
|
||||||
from ..tasks.config import get_setup_container
|
from ..tasks.config import get_setup_container
|
||||||
from ..tasks.main import Task
|
from ..tasks.main import Task
|
||||||
|
|
||||||
@ -46,7 +47,8 @@ def send_teams_webhook(
|
|||||||
if text:
|
if text:
|
||||||
message["sections"].append({"text": text})
|
message["sections"].append({"text": text})
|
||||||
|
|
||||||
response = requests.post(config.url, json=message)
|
config_url = get_secret_string_value(config.url)
|
||||||
|
response = requests.post(config_url, json=message)
|
||||||
if not response.ok:
|
if not response.ok:
|
||||||
logging.error("webhook failed %s %s", response.status_code, response.content)
|
logging.error("webhook failed %s %s", response.status_code, response.content)
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -33,11 +34,12 @@ from onefuzztypes.enums import (
|
|||||||
UpdateType,
|
UpdateType,
|
||||||
VmState,
|
VmState,
|
||||||
)
|
)
|
||||||
from onefuzztypes.models import Error
|
from onefuzztypes.models import Error, SecretData
|
||||||
from onefuzztypes.primitives import Container, PoolName, Region
|
from onefuzztypes.primitives import Container, PoolName, Region
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
from ..onefuzzlib.secrets import save_to_keyvault
|
||||||
from .azure.table import get_client
|
from .azure.table import get_client
|
||||||
from .telemetry import track_event_filtered
|
from .telemetry import track_event_filtered
|
||||||
from .updates import queue_update
|
from .updates import queue_update
|
||||||
@ -268,18 +270,49 @@ class ORMMixin(ModelMixin):
|
|||||||
|
|
||||||
return (partition_key, row_key)
|
return (partition_key, row_key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def hide_secrets(
|
||||||
|
cls,
|
||||||
|
model: BaseModel,
|
||||||
|
hider: Callable[["SecretData"], None],
|
||||||
|
visited: Set[int] = set(),
|
||||||
|
) -> None:
|
||||||
|
if id(model) in visited:
|
||||||
|
return
|
||||||
|
|
||||||
|
visited.add(id(model))
|
||||||
|
for field in model.__fields__:
|
||||||
|
field_data = getattr(model, field)
|
||||||
|
if isinstance(field_data, SecretData):
|
||||||
|
hider(field_data)
|
||||||
|
elif isinstance(field_data, List):
|
||||||
|
if len(field_data) > 0:
|
||||||
|
if not isinstance(field_data[0], BaseModel):
|
||||||
|
continue
|
||||||
|
for data in field_data:
|
||||||
|
cls.hide_secrets(data, hider, visited)
|
||||||
|
elif isinstance(field_data, dict):
|
||||||
|
for key in field_data:
|
||||||
|
if not isinstance(field_data[key], BaseModel):
|
||||||
|
continue
|
||||||
|
cls.hide_secrets(field_data[key], hider, visited)
|
||||||
|
else:
|
||||||
|
if isinstance(field_data, BaseModel):
|
||||||
|
cls.hide_secrets(field_data, hider, visited)
|
||||||
|
|
||||||
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
def save(self, new: bool = False, require_etag: bool = False) -> Optional[Error]:
|
||||||
|
self.__class__.hide_secrets(self, save_to_keyvault)
|
||||||
# TODO: migrate to an inspect.signature() model
|
# TODO: migrate to an inspect.signature() model
|
||||||
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
raw = self.raw(by_alias=True, exclude_none=True, exclude=self.save_exclude())
|
||||||
for key in raw:
|
for key in raw:
|
||||||
if not isinstance(raw[key], (str, int)):
|
if not isinstance(raw[key], (str, int)):
|
||||||
raw[key] = json.dumps(raw[key])
|
raw[key] = json.dumps(raw[key])
|
||||||
|
|
||||||
# for datetime fields that passed through filtering, use the real value,
|
|
||||||
# rather than a serialized form
|
|
||||||
for field in self.__fields__:
|
for field in self.__fields__:
|
||||||
if field not in raw:
|
if field not in raw:
|
||||||
continue
|
continue
|
||||||
|
# for datetime fields that passed through filtering, use the real value,
|
||||||
|
# rather than a serialized form
|
||||||
if self.__fields__[field].type_ == datetime:
|
if self.__fields__[field].type_ == datetime:
|
||||||
raw[field] = getattr(self, field)
|
raw[field] = getattr(self, field)
|
||||||
|
|
||||||
|
79
src/api-service/__app__/onefuzzlib/secrets.py
Normal file
79
src/api-service/__app__/onefuzzlib/secrets.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Tuple, Type, TypeVar, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from azure.keyvault.secrets import KeyVaultSecret
|
||||||
|
from onefuzztypes.models import SecretAddress, SecretData
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .azure.creds import get_instance_name, get_keyvault_client
|
||||||
|
|
||||||
|
A = TypeVar("A", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_keyvault(secret_data: SecretData) -> None:
|
||||||
|
if isinstance(secret_data.secret, SecretAddress):
|
||||||
|
return
|
||||||
|
|
||||||
|
secret_name = str(uuid4())
|
||||||
|
if isinstance(secret_data.secret, str):
|
||||||
|
secret_value = secret_data.secret
|
||||||
|
elif isinstance(secret_data.secret, BaseModel):
|
||||||
|
secret_value = secret_data.secret.json()
|
||||||
|
else:
|
||||||
|
raise Exception("invalid secret data")
|
||||||
|
|
||||||
|
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
|
||||||
|
secret_data.secret = SecretAddress(url=kv.id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_secret_string_value(self: SecretData[str]) -> str:
|
||||||
|
if isinstance(self.secret, SecretAddress):
|
||||||
|
secret = get_secret(self.secret.url)
|
||||||
|
return cast(str, secret.value)
|
||||||
|
else:
|
||||||
|
return self.secret
|
||||||
|
|
||||||
|
|
||||||
|
def get_keyvault_address() -> str:
|
||||||
|
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
|
||||||
|
return f"https://{get_instance_name()}-vault.vault.azure.net"
|
||||||
|
|
||||||
|
|
||||||
|
def store_in_keyvault(
|
||||||
|
keyvault_url: str, secret_name: str, secret_value: str
|
||||||
|
) -> KeyVaultSecret:
|
||||||
|
keyvault_client = get_keyvault_client(keyvault_url)
|
||||||
|
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
|
||||||
|
return kvs
|
||||||
|
|
||||||
|
|
||||||
|
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
|
||||||
|
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
|
||||||
|
u = urlparse(secret_url)
|
||||||
|
vault_url = f"{u.scheme}://{u.netloc}"
|
||||||
|
secret_name = u.path.split("/")[2]
|
||||||
|
return (vault_url, secret_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_secret(secret_url: str) -> KeyVaultSecret:
|
||||||
|
(vault_url, secret_name) = parse_secret_url(secret_url)
|
||||||
|
keyvault_client = get_keyvault_client(vault_url)
|
||||||
|
return keyvault_client.get_secret(secret_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
|
||||||
|
secret = get_secret(secret_url)
|
||||||
|
return model.parse_raw(secret.value)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_secret(secret_url: str) -> None:
|
||||||
|
(vault_url, secret_name) = parse_secret_url(secret_url)
|
||||||
|
keyvault_client = get_keyvault_client(vault_url)
|
||||||
|
keyvault_client.begin_delete_secret(secret_name).wait()
|
91
src/api-service/tests/test_secrets.py
Normal file
91
src/api-service/tests/test_secrets.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onefuzztypes.enums import OS, ContainerType
|
||||||
|
from onefuzztypes.job_templates import (
|
||||||
|
JobTemplate,
|
||||||
|
JobTemplateIndex,
|
||||||
|
JobTemplateNotification,
|
||||||
|
)
|
||||||
|
from onefuzztypes.models import (
|
||||||
|
JobConfig,
|
||||||
|
Notification,
|
||||||
|
NotificationConfig,
|
||||||
|
SecretAddress,
|
||||||
|
SecretData,
|
||||||
|
TeamsTemplate,
|
||||||
|
)
|
||||||
|
from onefuzztypes.primitives import Container
|
||||||
|
|
||||||
|
from __app__.onefuzzlib.orm import ORMMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecret(unittest.TestCase):
|
||||||
|
def test_hide(self) -> None:
|
||||||
|
def hider(secret_data: SecretData) -> None:
|
||||||
|
if not isinstance(secret_data.secret, SecretAddress):
|
||||||
|
secret_data.secret = SecretAddress(url="blah blah")
|
||||||
|
|
||||||
|
notification = Notification(
|
||||||
|
container=Container("data"),
|
||||||
|
config=TeamsTemplate(url=SecretData(secret="http://test")),
|
||||||
|
)
|
||||||
|
ORMMixin.hide_secrets(notification, hider)
|
||||||
|
|
||||||
|
if isinstance(notification.config, TeamsTemplate):
|
||||||
|
self.assertIsInstance(notification.config.url, SecretData)
|
||||||
|
self.assertIsInstance(notification.config.url.secret, SecretAddress)
|
||||||
|
else:
|
||||||
|
self.fail(f"Invalid config type {type(notification.config)}")
|
||||||
|
|
||||||
|
def test_hide_nested_list(self) -> None:
|
||||||
|
def hider(secret_data: SecretData) -> None:
|
||||||
|
if not isinstance(secret_data.secret, SecretAddress):
|
||||||
|
secret_data.secret = SecretAddress(url="blah blah")
|
||||||
|
|
||||||
|
job_template_index = JobTemplateIndex(
|
||||||
|
name="test",
|
||||||
|
template=JobTemplate(
|
||||||
|
os=OS.linux,
|
||||||
|
job=JobConfig(name="test", build="test", project="test", duration=1),
|
||||||
|
tasks=[],
|
||||||
|
notifications=[
|
||||||
|
JobTemplateNotification(
|
||||||
|
container_type=ContainerType.unique_inputs,
|
||||||
|
notification=NotificationConfig(
|
||||||
|
config=TeamsTemplate(url=SecretData(secret="http://test"))
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
user_fields=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
ORMMixin.hide_secrets(job_template_index, hider)
|
||||||
|
notification = job_template_index.template.notifications[0].notification
|
||||||
|
if isinstance(notification.config, TeamsTemplate):
|
||||||
|
self.assertIsInstance(notification.config.url, SecretData)
|
||||||
|
self.assertIsInstance(notification.config.url.secret, SecretAddress)
|
||||||
|
else:
|
||||||
|
self.fail(f"Invalid config type {type(notification.config)}")
|
||||||
|
|
||||||
|
def test_read_secret(self) -> None:
|
||||||
|
json_data = """
|
||||||
|
{
|
||||||
|
"notification_id": "b52b24d1-eec6-46c9-b06a-818a997da43c",
|
||||||
|
"container": "data",
|
||||||
|
"config" : {"url": {"secret": {"url": "http://test"}}}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
data = json.loads(json_data)
|
||||||
|
notification = Notification.parse_obj(data)
|
||||||
|
self.assertIsInstance(notification.config, TeamsTemplate)
|
||||||
|
if isinstance(notification.config, TeamsTemplate):
|
||||||
|
self.assertIsInstance(notification.config.url, SecretData)
|
||||||
|
self.assertIsInstance(notification.config.url.secret, SecretAddress)
|
||||||
|
else:
|
||||||
|
self.fail(f"Invalid config type {type(notification.config)}")
|
@ -11,6 +11,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import asdict, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -381,6 +382,8 @@ def serialize(data: Any) -> Any:
|
|||||||
return str(data)
|
return str(data)
|
||||||
if isinstance(data, (int, str)):
|
if isinstance(data, (int, str)):
|
||||||
return data
|
return data
|
||||||
|
if is_dataclass(data):
|
||||||
|
return {serialize(a): serialize(b) for (a, b) in asdict(data).items()}
|
||||||
|
|
||||||
raise Exception("unknown type %s" % type(data))
|
raise Exception("unknown type %s" % type(data))
|
||||||
|
|
||||||
|
@ -56,7 +56,8 @@
|
|||||||
"Managed Identity Operator": "f1a07417-d97a-45cb-824c-7a7467783830",
|
"Managed Identity Operator": "f1a07417-d97a-45cb-824c-7a7467783830",
|
||||||
"Network Contributor": "4d97b98b-1d4f-4787-a291-c67834d212e7",
|
"Network Contributor": "4d97b98b-1d4f-4787-a291-c67834d212e7",
|
||||||
"Storage Account Contributor": "17d1049b-9a84-46fb-8f53-869881c3d3ab",
|
"Storage Account Contributor": "17d1049b-9a84-46fb-8f53-869881c3d3ab",
|
||||||
"Virtual Machine Contributor": "9980e02c-c2be-4d73-94e8-173b1dc7cf3c"
|
"Virtual Machine Contributor": "9980e02c-c2be-4d73-94e8-173b1dc7cf3c",
|
||||||
|
"keyVaultName": "[concat(parameters('name'), '-vault')]"
|
||||||
},
|
},
|
||||||
"functions": [
|
"functions": [
|
||||||
{
|
{
|
||||||
@ -101,6 +102,34 @@
|
|||||||
"apiVersion": "2018-11-30",
|
"apiVersion": "2018-11-30",
|
||||||
"location": "[resourceGroup().location]"
|
"location": "[resourceGroup().location]"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"type": "Microsoft.KeyVault/vaults",
|
||||||
|
"apiVersion": "2019-09-01",
|
||||||
|
"name": "[variables('keyVaultName')]",
|
||||||
|
"location": "[resourceGroup().location]",
|
||||||
|
"properties": {
|
||||||
|
"enabledForDiskEncryption": false,
|
||||||
|
"enabledForTemplateDeployment": true,
|
||||||
|
"tenantId": "[subscription().tenantId]",
|
||||||
|
"accessPolicies": [
|
||||||
|
{
|
||||||
|
"objectId": "[reference(resourceId('Microsoft.Web/sites', parameters('name')), '2019-08-01', 'full').identity.principalId]",
|
||||||
|
"tenantId": "[subscription().tenantId]",
|
||||||
|
"permissions": {
|
||||||
|
"secrets": ["get", "list", "set", "delete"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sku": {
|
||||||
|
"name": "standard",
|
||||||
|
"family": "A"
|
||||||
|
},
|
||||||
|
"networkAcls": {
|
||||||
|
"defaultAction": "Allow",
|
||||||
|
"bypass": "AzureServices"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"apiVersion": "2018-11-01",
|
"apiVersion": "2018-11-01",
|
||||||
"name": "[parameters('name')]",
|
"name": "[parameters('name')]",
|
||||||
|
@ -4,10 +4,11 @@
|
|||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
|
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator, validator
|
from pydantic import BaseModel, Field, root_validator, validator
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from .consts import ONE_HOUR, SEVEN_DAYS
|
from .consts import ONE_HOUR, SEVEN_DAYS
|
||||||
from .enums import (
|
from .enums import (
|
||||||
@ -41,6 +42,39 @@ class UserInfo(BaseModel):
|
|||||||
upn: Optional[str]
|
upn: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
# Stores the address of a secret
|
||||||
|
class SecretAddress(BaseModel):
|
||||||
|
# keyvault address of a secret
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
# This class allows us to store some data that are intended to be secret
|
||||||
|
# The secret field stores either the raw data or the address of that data
|
||||||
|
# This class allows us to maintain backward compatibility with existing
|
||||||
|
# NotificationTemplate classes
|
||||||
|
@dataclass
|
||||||
|
class SecretData(Generic[T]):
|
||||||
|
secret: Union[T, SecretAddress]
|
||||||
|
|
||||||
|
def __init__(self, secret: Union[T, SecretAddress]):
|
||||||
|
if isinstance(secret, dict):
|
||||||
|
self.secret = SecretAddress.parse_obj(secret)
|
||||||
|
else:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if isinstance(self.secret, SecretAddress):
|
||||||
|
return str(self.secret)
|
||||||
|
else:
|
||||||
|
return "[REDACTED]"
|
||||||
|
|
||||||
|
|
||||||
class EnumModel(BaseModel):
|
class EnumModel(BaseModel):
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def exactly_one(cls: Any, values: Any) -> Any:
|
def exactly_one(cls: Any, values: Any) -> Any:
|
||||||
@ -226,7 +260,7 @@ class ADODuplicateTemplate(BaseModel):
|
|||||||
|
|
||||||
class ADOTemplate(BaseModel):
|
class ADOTemplate(BaseModel):
|
||||||
base_url: str
|
base_url: str
|
||||||
auth_token: str
|
auth_token: SecretData[str]
|
||||||
project: str
|
project: str
|
||||||
type: str
|
type: str
|
||||||
unique_fields: List[str]
|
unique_fields: List[str]
|
||||||
@ -234,15 +268,33 @@ class ADOTemplate(BaseModel):
|
|||||||
ado_fields: Dict[str, str]
|
ado_fields: Dict[str, str]
|
||||||
on_duplicate: ADODuplicateTemplate
|
on_duplicate: ADODuplicateTemplate
|
||||||
|
|
||||||
def redact(self) -> None:
|
# validator needed for backward compatibility
|
||||||
self.auth_token = "***"
|
@validator("auth_token", pre=True, always=True)
|
||||||
|
def validate_auth_token(cls, v: Any) -> SecretData:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return SecretData(secret=v)
|
||||||
|
elif isinstance(v, SecretData):
|
||||||
|
return v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
return SecretData(secret=v["secret"])
|
||||||
|
else:
|
||||||
|
raise TypeError(f"invalid datatype {type(v)}")
|
||||||
|
|
||||||
|
|
||||||
class TeamsTemplate(BaseModel):
|
class TeamsTemplate(BaseModel):
|
||||||
url: str
|
url: SecretData[str]
|
||||||
|
|
||||||
def redact(self) -> None:
|
# validator needed for backward compatibility
|
||||||
self.url = "***"
|
@validator("url", pre=True, always=True)
|
||||||
|
def validate_url(cls, v: Any) -> SecretData:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return SecretData(secret=v)
|
||||||
|
elif isinstance(v, SecretData):
|
||||||
|
return v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
return SecretData(secret=v["secret"])
|
||||||
|
else:
|
||||||
|
raise TypeError(f"invalid datatype {type(v)}")
|
||||||
|
|
||||||
|
|
||||||
class ContainerDefinition(BaseModel):
|
class ContainerDefinition(BaseModel):
|
||||||
@ -408,7 +460,7 @@ class GithubAuth(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class GithubIssueTemplate(BaseModel):
|
class GithubIssueTemplate(BaseModel):
|
||||||
auth: GithubAuth
|
auth: SecretData[GithubAuth]
|
||||||
organization: str
|
organization: str
|
||||||
repository: str
|
repository: str
|
||||||
title: str
|
title: str
|
||||||
@ -418,9 +470,20 @@ class GithubIssueTemplate(BaseModel):
|
|||||||
labels: List[str]
|
labels: List[str]
|
||||||
on_duplicate: GithubIssueDuplicate
|
on_duplicate: GithubIssueDuplicate
|
||||||
|
|
||||||
def redact(self) -> None:
|
# validator needed for backward compatibility
|
||||||
self.auth.user = "***"
|
@validator("auth", pre=True, always=True)
|
||||||
self.auth.personal_access_token = "***"
|
def validate_auth(cls, v: Any) -> SecretData:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return SecretData(secret=v)
|
||||||
|
elif isinstance(v, SecretData):
|
||||||
|
return v
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
try:
|
||||||
|
return SecretData(GithubAuth.parse_obj(v))
|
||||||
|
except Exception:
|
||||||
|
return SecretData(secret=v["secret"])
|
||||||
|
else:
|
||||||
|
raise TypeError(f"invalid datatype {type(v)}")
|
||||||
|
|
||||||
|
|
||||||
NotificationTemplate = Union[ADOTemplate, TeamsTemplate, GithubIssueTemplate]
|
NotificationTemplate = Union[ADOTemplate, TeamsTemplate, GithubIssueTemplate]
|
||||||
|
@ -5,14 +5,34 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from onefuzztypes.models import Scaleset, SecretData, TeamsTemplate
|
||||||
|
|
||||||
from onefuzztypes.models import Scaleset, TeamsTemplate
|
|
||||||
from onefuzztypes.requests import NotificationCreate
|
from onefuzztypes.requests import NotificationCreate
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
|
||||||
class TestModelsVerify(unittest.TestCase):
|
class TestModelsVerify(unittest.TestCase):
|
||||||
def test_model(self) -> None:
|
def test_model(self) -> None:
|
||||||
|
data = {
|
||||||
|
"container": "data",
|
||||||
|
"config": {"url": {"secret": "https://www.contoso.com/"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
notification = NotificationCreate.parse_obj(data)
|
||||||
|
self.assertIsInstance(notification.config, TeamsTemplate)
|
||||||
|
self.assertIsInstance(notification.config.url, SecretData)
|
||||||
|
self.assertEqual(
|
||||||
|
notification.config.url.secret,
|
||||||
|
"https://www.contoso.com/",
|
||||||
|
"mismatch secret value",
|
||||||
|
)
|
||||||
|
|
||||||
|
missing_container = {
|
||||||
|
"config": {"url": "https://www.contoso.com/"},
|
||||||
|
}
|
||||||
|
with self.assertRaises(ValidationError):
|
||||||
|
NotificationCreate.parse_obj(missing_container)
|
||||||
|
|
||||||
|
def test_legacy_model(self) -> None:
|
||||||
data = {
|
data = {
|
||||||
"container": "data",
|
"container": "data",
|
||||||
"config": {"url": "https://www.contoso.com/"},
|
"config": {"url": "https://www.contoso.com/"},
|
||||||
@ -20,6 +40,7 @@ class TestModelsVerify(unittest.TestCase):
|
|||||||
|
|
||||||
notification = NotificationCreate.parse_obj(data)
|
notification = NotificationCreate.parse_obj(data)
|
||||||
self.assertIsInstance(notification.config, TeamsTemplate)
|
self.assertIsInstance(notification.config, TeamsTemplate)
|
||||||
|
self.assertIsInstance(notification.config.url, SecretData)
|
||||||
|
|
||||||
missing_container = {
|
missing_container = {
|
||||||
"config": {"url": "https://www.contoso.com/"},
|
"config": {"url": "https://www.contoso.com/"},
|
||||||
|
Reference in New Issue
Block a user