mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-15 19:38:11 +00:00
* migrate to msgraph * add subscription id to query_microsoft_graph * migrating remaingin references * formatting * adding missing dependencies * flake fix * fix get_tenant_id * cleanup * formatting * migrate application creation in deploy.py * foramt * mypy fix * isort * isort * format * bug fixes * specify the correct signInAudience * fix backing service principal creation fix preauthorized application * remove remaining references to graphrbac * fix ms graph authentication * formatting * fix typo * format * deployment fix * set implicitGrantSettings in the deployment * format * fix deployment * fix graph authentication on the server * use the current cli logged in account to retrive the backend token cache * assign the the msgraph app role permissions to the web app during the deployment * formatting * fix build * build fix * fix bandit issue * mypy fix * isort * deploy fixes * formatting * remove assign_app_permissions * mypy fix * build fix * mypy fix * format * formatting * flake fix * remove webapp identity permission assignment * remove unused reference to assign_app_role * remove manual registration message * fixing name and logging * address PR coments * address PR comments * build fix * lint * lint * mypy fix * mypy fix * formatting * address PR comments * linting * lint * remove ONEFUZZ_AAD_GROUP_ID check * regenerate webhook_events.md * change return type of query_microsoft_graph_list * fix tenant_id Co-authored-by: Marc Greisen <marc@greisen.org> Co-authored-by: Stas <stishkin@live.com>
250 lines
7.2 KiB
Python
250 lines
7.2 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import functools
|
|
import logging
|
|
import os
|
|
import urllib.parse
|
|
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
|
|
from uuid import UUID
|
|
|
|
import requests
|
|
from azure.core.exceptions import ClientAuthenticationError
|
|
from azure.identity import DefaultAzureCredential
|
|
from azure.keyvault.secrets import SecretClient
|
|
from azure.mgmt.resource import ResourceManagementClient
|
|
from azure.mgmt.subscription import SubscriptionClient
|
|
from memoization import cached
|
|
from msrestazure.azure_active_directory import MSIAuthentication
|
|
from msrestazure.tools import parse_resource_id
|
|
from onefuzztypes.primitives import Container, Region
|
|
|
|
from .monkeypatch import allow_more_workers, reduce_logging
|
|
|
|
# https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
|
|
GRAPH_RESOURCE = "https://graph.microsoft.com"
|
|
GRAPH_RESOURCE_ENDPOINT = "https://graph.microsoft.com/v1.0"
|
|
|
|
|
|
@cached
|
|
def get_msi() -> MSIAuthentication:
|
|
allow_more_workers()
|
|
reduce_logging()
|
|
return MSIAuthentication()
|
|
|
|
|
|
@cached
|
|
def get_identity() -> DefaultAzureCredential:
|
|
allow_more_workers()
|
|
reduce_logging()
|
|
return DefaultAzureCredential()
|
|
|
|
|
|
@cached
|
|
def get_base_resource_group() -> Any: # should be str
|
|
return parse_resource_id(os.environ["ONEFUZZ_RESOURCE_GROUP"])["resource_group"]
|
|
|
|
|
|
@cached
|
|
def get_base_region() -> Region:
|
|
client = ResourceManagementClient(
|
|
credential=get_identity(), subscription_id=get_subscription()
|
|
)
|
|
group = client.resource_groups.get(get_base_resource_group())
|
|
return Region(group.location)
|
|
|
|
|
|
@cached
|
|
def get_subscription() -> Any: # should be str
|
|
return parse_resource_id(os.environ["ONEFUZZ_DATA_STORAGE"])["subscription"]
|
|
|
|
|
|
@cached
|
|
def get_insights_instrumentation_key() -> Any: # should be str
|
|
return os.environ["APPINSIGHTS_INSTRUMENTATIONKEY"]
|
|
|
|
|
|
@cached
|
|
def get_insights_appid() -> str:
|
|
return os.environ["APPINSIGHTS_APPID"]
|
|
|
|
|
|
@cached
|
|
def get_instance_name() -> str:
|
|
return os.environ["ONEFUZZ_INSTANCE_NAME"]
|
|
|
|
|
|
@cached
|
|
def get_instance_url() -> str:
|
|
return "https://%s.azurewebsites.net" % get_instance_name()
|
|
|
|
|
|
@cached
|
|
def get_instance_id() -> UUID:
|
|
from .containers import get_blob
|
|
from .storage import StorageType
|
|
|
|
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
|
|
if blob is None:
|
|
raise Exception("missing instance_id")
|
|
return UUID(blob.decode())
|
|
|
|
|
|
DAY_IN_SECONDS = 60 * 60 * 24
|
|
|
|
|
|
@cached(ttl=DAY_IN_SECONDS)
|
|
def get_regions() -> List[Region]:
|
|
subscription = get_subscription()
|
|
client = SubscriptionClient(credential=get_identity())
|
|
locations = client.subscriptions.list_locations(subscription)
|
|
return sorted([Region(x.name) for x in locations])
|
|
|
|
|
|
class GraphQueryError(Exception):
|
|
def __init__(self, message: str, status_code: Optional[int]) -> None:
|
|
super(GraphQueryError, self).__init__(message)
|
|
self.message = message
|
|
self.status_code = status_code
|
|
|
|
|
|
def query_microsoft_graph(
|
|
method: str,
|
|
resource: str,
|
|
params: Optional[Dict] = None,
|
|
body: Optional[Dict] = None,
|
|
) -> Dict:
|
|
cred = get_identity()
|
|
access_token = cred.get_token(f"{GRAPH_RESOURCE}/.default")
|
|
|
|
url = urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource)
|
|
headers = {
|
|
"Authorization": "Bearer %s" % access_token.token,
|
|
"Content-Type": "application/json",
|
|
}
|
|
response = requests.request(
|
|
method=method, url=url, headers=headers, params=params, json=body
|
|
)
|
|
|
|
if 200 <= response.status_code < 300:
|
|
if response.content and response.content.strip():
|
|
json = response.json()
|
|
if isinstance(json, Dict):
|
|
return json
|
|
else:
|
|
raise GraphQueryError(
|
|
"invalid data expected a json object: HTTP"
|
|
f" {response.status_code} - {json}",
|
|
response.status_code,
|
|
)
|
|
else:
|
|
return {}
|
|
else:
|
|
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
|
|
raise GraphQueryError(
|
|
f"request did not succeed: HTTP {response.status_code} - {error_text}",
|
|
response.status_code,
|
|
)
|
|
|
|
|
|
def query_microsoft_graph_list(
|
|
method: str,
|
|
resource: str,
|
|
params: Optional[Dict] = None,
|
|
body: Optional[Dict] = None,
|
|
) -> List[Any]:
|
|
result = query_microsoft_graph(
|
|
method,
|
|
resource,
|
|
params,
|
|
body,
|
|
)
|
|
value = result.get("value")
|
|
if isinstance(value, list):
|
|
return value
|
|
else:
|
|
raise GraphQueryError("Expected data containing a list of values", None)
|
|
|
|
|
|
def is_member_of(group_id: str, member_id: str) -> bool:
|
|
body = {"groupIds": [group_id]}
|
|
response = query_microsoft_graph_list(
|
|
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
|
|
)
|
|
return group_id in response
|
|
|
|
|
|
@cached
|
|
def get_scaleset_identity_resource_path() -> str:
|
|
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
|
resource_group_path = "/subscriptions/%s/resourceGroups/%s/providers" % (
|
|
get_subscription(),
|
|
get_base_resource_group(),
|
|
)
|
|
return "%s/Microsoft.ManagedIdentity/userAssignedIdentities/%s" % (
|
|
resource_group_path,
|
|
scaleset_id_name,
|
|
)
|
|
|
|
|
|
@cached
|
|
def get_scaleset_principal_id() -> UUID:
|
|
api_version = "2018-11-30" # matches the apiversion in the deployment template
|
|
client = ResourceManagementClient(
|
|
credential=get_identity(), subscription_id=get_subscription()
|
|
)
|
|
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
|
|
return UUID(uid.properties["principalId"])
|
|
|
|
|
|
@cached
|
|
def get_keyvault_client(vault_url: str) -> SecretClient:
|
|
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
|
|
|
|
|
def clear_azure_client_cache() -> None:
|
|
# clears the memoization of the Azure clients.
|
|
|
|
from .compute import get_compute_client
|
|
from .containers import get_blob_service
|
|
from .network_mgmt_client import get_network_client
|
|
from .storage import get_mgmt_client
|
|
|
|
# currently memoization.cache does not project the wrapped function's types.
|
|
# As a workaround, CI comments out the `cached` wrapper, then runs the type
|
|
# validation. This enables calling the wrapper's clear_cache if it's not
|
|
# disabled.
|
|
for func in [
|
|
get_msi,
|
|
get_identity,
|
|
get_compute_client,
|
|
get_blob_service,
|
|
get_network_client,
|
|
get_mgmt_client,
|
|
]:
|
|
clear_func = getattr(func, "clear_cache", None)
|
|
if clear_func is not None:
|
|
clear_func()
|
|
|
|
|
|
T = TypeVar("T", bound=Callable[..., Any])
|
|
|
|
|
|
class retry_on_auth_failure:
|
|
def __call__(self, func: T) -> T:
|
|
@functools.wraps(func)
|
|
def decorated(*args, **kwargs): # type: ignore
|
|
try:
|
|
return func(*args, **kwargs)
|
|
except ClientAuthenticationError as err:
|
|
logging.warning(
|
|
"clearing authentication cache after auth failure: %s", err
|
|
)
|
|
|
|
clear_azure_client_cache()
|
|
return func(*args, **kwargs)
|
|
|
|
return cast(T, decorated)
|