mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
82 lines
2.5 KiB
Python
82 lines
2.5 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
|
|
import os
|
|
from typing import Tuple, Type, TypeVar, cast
|
|
from urllib.parse import urlparse
|
|
from uuid import uuid4
|
|
|
|
from azure.keyvault.secrets import KeyVaultSecret
|
|
from onefuzztypes.models import SecretAddress, SecretData
|
|
from pydantic import BaseModel
|
|
|
|
from .azure.creds import get_keyvault_client
|
|
|
|
A = TypeVar("A", bound=BaseModel)
|
|
|
|
|
|
def save_to_keyvault(secret_data: SecretData) -> None:
|
|
if isinstance(secret_data.secret, SecretAddress):
|
|
return
|
|
|
|
secret_name = str(uuid4())
|
|
if isinstance(secret_data.secret, str):
|
|
secret_value = secret_data.secret
|
|
elif isinstance(secret_data.secret, BaseModel):
|
|
secret_value = secret_data.secret.json()
|
|
else:
|
|
raise Exception("invalid secret data")
|
|
|
|
kv = store_in_keyvault(get_keyvault_address(), secret_name, secret_value)
|
|
secret_data.secret = SecretAddress(url=kv.id)
|
|
|
|
|
|
def get_secret_string_value(self: SecretData[str]) -> str:
|
|
if isinstance(self.secret, SecretAddress):
|
|
secret = get_secret(self.secret.url)
|
|
return cast(str, secret.value)
|
|
else:
|
|
return self.secret
|
|
|
|
|
|
def get_keyvault_address() -> str:
|
|
# https://docs.microsoft.com/en-us/azure/key-vault/general/about-keys-secrets-certificates#vault-name-and-object-name
|
|
keyvault_name = os.environ["ONEFUZZ_KEYVAULT"]
|
|
return f"https://{keyvault_name}.vault.azure.net"
|
|
|
|
|
|
def store_in_keyvault(
|
|
keyvault_url: str, secret_name: str, secret_value: str
|
|
) -> KeyVaultSecret:
|
|
keyvault_client = get_keyvault_client(keyvault_url)
|
|
kvs: KeyVaultSecret = keyvault_client.set_secret(secret_name, secret_value)
|
|
return kvs
|
|
|
|
|
|
def parse_secret_url(secret_url: str) -> Tuple[str, str]:
|
|
# format: https://{vault-name}.vault.azure.net/secrets/{secret-name}/{version}
|
|
u = urlparse(secret_url)
|
|
vault_url = f"{u.scheme}://{u.netloc}"
|
|
secret_name = u.path.split("/")[2]
|
|
return (vault_url, secret_name)
|
|
|
|
|
|
def get_secret(secret_url: str) -> KeyVaultSecret:
|
|
(vault_url, secret_name) = parse_secret_url(secret_url)
|
|
keyvault_client = get_keyvault_client(vault_url)
|
|
return keyvault_client.get_secret(secret_name)
|
|
|
|
|
|
def get_secret_obj(secret_url: str, model: Type[A]) -> A:
|
|
secret = get_secret(secret_url)
|
|
return model.parse_raw(secret.value)
|
|
|
|
|
|
def delete_secret(secret_url: str) -> None:
|
|
(vault_url, secret_name) = parse_secret_url(secret_url)
|
|
keyvault_client = get_keyvault_client(vault_url)
|
|
keyvault_client.begin_delete_secret(secret_name).wait()
|