Files
onefuzz/src/api-service/__app__/onefuzzlib/azure/creds.py
Cheick Keita 98cd7c9c56 migrate to msgraph (#966)
* migrate to msgraph

* add subscription id to query_microsoft_graph

* migrating remaining references

* formatting

* adding missing dependencies

* flake fix

* fix get_tenant_id

* cleanup

* formatting

* migrate application creation in deploy.py

* format

* mypy fix

* isort

* isort

* format

* bug fixes

* specify the correct signInAudience

* fix backing service principal creation
fix preauthorized application

* remove remaining references to graphrbac

* fix ms graph authentication

* formatting

* fix typo

* format

* deployment fix

* set implicitGrantSettings in the deployment

* format

* fix deployment

* fix graph authentication on the server

* use the current cli logged in account to retrieve the backend token cache

* assign the msgraph app role permissions to the web app during the deployment

* formatting

* fix build

* build fix

* fix bandit issue

* mypy fix

* isort

* deploy fixes

* formatting

* remove assign_app_permissions

* mypy fix

* build fix

* mypy fix

* format

* formatting

* flake fix

* remove webapp identity permission assignment

* remove unused reference to assign_app_role

* remove manual registration message

* fixing name and logging

* address PR comments

* address PR comments

* build fix

* lint

* lint

* mypy fix

* mypy fix

* formatting

* address PR comments

* linting

* lint

* remove ONEFUZZ_AAD_GROUP_ID check

* regenerate webhook_events.md

* change return type of query_microsoft_graph_list

* fix tenant_id

Co-authored-by: Marc Greisen <marc@greisen.org>
Co-authored-by: Stas <stishkin@live.com>
2021-10-22 11:59:05 -07:00

250 lines
7.2 KiB
Python

#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import functools
import logging
import os
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from uuid import UUID
import requests
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.subscription import SubscriptionClient
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container, Region
from .monkeypatch import allow_more_workers, reduce_logging
# Microsoft Graph REST API (v1.0) overview:
# https://docs.microsoft.com/en-us/graph/api/overview?view=graph-rest-1.0
#
# GRAPH_RESOURCE is the base resource URL (used to build the token scope);
# GRAPH_RESOURCE_ENDPOINT is the versioned root that request paths are joined to.
GRAPH_RESOURCE = "https://graph.microsoft.com"
GRAPH_RESOURCE_ENDPOINT = f"{GRAPH_RESOURCE}/v1.0"
@cached
def get_msi() -> MSIAuthentication:
    """Return a memoized ``MSIAuthentication`` credential.

    The ``.monkeypatch`` helpers are applied before the credential is built.
    """
    allow_more_workers()
    reduce_logging()
    msi_credentials = MSIAuthentication()
    return msi_credentials
@cached
def get_identity() -> DefaultAzureCredential:
    """Return a memoized ``DefaultAzureCredential``.

    The ``.monkeypatch`` helpers are applied before the credential is built.
    """
    allow_more_workers()
    reduce_logging()
    identity = DefaultAzureCredential()
    return identity
@cached
def get_base_resource_group() -> Any:  # should be str
    """Return the resource group parsed from ``ONEFUZZ_RESOURCE_GROUP``.

    Raises:
        KeyError: if the environment variable is not set.
    """
    resource_id = os.environ["ONEFUZZ_RESOURCE_GROUP"]
    parsed = parse_resource_id(resource_id)
    return parsed["resource_group"]
@cached
def get_base_region() -> Region:
    """Return the region (location) of the base resource group (memoized)."""
    credential = get_identity()
    subscription_id = get_subscription()
    resource_client = ResourceManagementClient(
        credential=credential, subscription_id=subscription_id
    )
    group_name = get_base_resource_group()
    resource_group = resource_client.resource_groups.get(group_name)
    return Region(resource_group.location)
@cached
def get_subscription() -> Any:  # should be str
    """Return the subscription parsed from ``ONEFUZZ_DATA_STORAGE``.

    Raises:
        KeyError: if the environment variable is not set.
    """
    storage_resource_id = os.environ["ONEFUZZ_DATA_STORAGE"]
    parsed = parse_resource_id(storage_resource_id)
    return parsed["subscription"]
@cached
def get_insights_instrumentation_key() -> Any:  # should be str
    """Return the value of ``APPINSIGHTS_INSTRUMENTATIONKEY`` (memoized).

    Raises:
        KeyError: if the environment variable is not set.
    """
    instrumentation_key = os.environ["APPINSIGHTS_INSTRUMENTATIONKEY"]
    return instrumentation_key
@cached
def get_insights_appid() -> str:
    """Return the value of ``APPINSIGHTS_APPID`` (memoized).

    Raises:
        KeyError: if the environment variable is not set.
    """
    app_id = os.environ["APPINSIGHTS_APPID"]
    return app_id
@cached
def get_instance_name() -> str:
    """Return the value of ``ONEFUZZ_INSTANCE_NAME`` (memoized).

    Raises:
        KeyError: if the environment variable is not set.
    """
    instance_name = os.environ["ONEFUZZ_INSTANCE_NAME"]
    return instance_name
@cached
def get_instance_url() -> str:
    """Return the ``azurewebsites.net`` URL built from the instance name."""
    instance_name = get_instance_name()
    return f"https://{instance_name}.azurewebsites.net"
@cached
def get_instance_id() -> UUID:
    """Return the instance id read from the ``base-config`` container.

    The id is read from the ``instance_id`` blob in config storage and parsed
    as a UUID.

    Raises:
        Exception: if the ``instance_id`` blob does not exist.
    """
    # Imported inside the function rather than at module level; presumably to
    # avoid a circular import with these sibling modules.
    from .containers import get_blob
    from .storage import StorageType

    raw_id = get_blob(Container("base-config"), "instance_id", StorageType.config)
    if raw_id is not None:
        return UUID(raw_id.decode())
    raise Exception("missing instance_id")
# Cache lifetime for get_regions(): one day, expressed in seconds.
DAY_IN_SECONDS = 24 * 60 * 60


@cached(ttl=DAY_IN_SECONDS)
def get_regions() -> List[Region]:
    """Return the sorted regions listed for the subscription.

    The result is cached with a ttl of ``DAY_IN_SECONDS``.
    """
    subscription_id = get_subscription()
    subscription_client = SubscriptionClient(credential=get_identity())
    regions = [
        Region(location.name)
        for location in subscription_client.subscriptions.list_locations(
            subscription_id
        )
    ]
    regions.sort()
    return regions
class GraphQueryError(Exception):
    """Error raised when a Microsoft Graph query fails or returns bad data.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, message: str, status_code: Optional[int]) -> None:
        super().__init__(message)
        # HTTP status code of the response, or None when no code applies.
        self.status_code = status_code
        # Human-readable description; also used as the exception text.
        self.message = message
def query_microsoft_graph(
    method: str,
    resource: str,
    params: Optional[Dict] = None,
    body: Optional[Dict] = None,
    *,
    timeout: float = 60.0,
) -> Dict:
    """Send a request to the Microsoft Graph API and return the JSON object.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        resource: Path joined onto ``GRAPH_RESOURCE_ENDPOINT``.
        params: Optional query-string parameters.
        body: Optional request payload, sent as JSON.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON object, or ``{}`` when a successful response has an
        empty (or whitespace-only) body.

    Raises:
        GraphQueryError: if the request times out, the status code is not
            2xx, the body is not valid JSON, or the JSON is not an object.
    """
    cred = get_identity()
    access_token = cred.get_token(f"{GRAPH_RESOURCE}/.default")

    # The trailing slash keeps the version segment of the endpoint when a
    # relative resource path is joined onto it.
    url = urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource)
    headers = {
        "Authorization": f"Bearer {access_token.token}",
        "Content-Type": "application/json",
    }

    try:
        response = requests.request(
            method=method,
            url=url,
            headers=headers,
            params=params,
            json=body,
            timeout=timeout,
        )
    except requests.exceptions.Timeout as err:
        raise GraphQueryError(
            f"request did not succeed: timed out after {timeout}s - {err}", None
        ) from err

    status_code = response.status_code
    if not 200 <= status_code < 300:
        error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
        raise GraphQueryError(
            f"request did not succeed: HTTP {status_code} - {error_text}",
            status_code,
        )

    if not (response.content and response.content.strip()):
        return {}

    try:
        data = response.json()
    except ValueError as err:
        # Every JSON decode error requests can raise is a ValueError subclass.
        raise GraphQueryError(
            f"invalid data expected a json object: HTTP {status_code} - {err}",
            status_code,
        ) from err

    if not isinstance(data, dict):
        raise GraphQueryError(
            f"invalid data expected a json object: HTTP {status_code} - {data}",
            status_code,
        )
    return data
def query_microsoft_graph_list(
    method: str,
    resource: str,
    params: Optional[Dict] = None,
    body: Optional[Dict] = None,
) -> List[Any]:
    """Query Microsoft Graph and return the list under the ``value`` key.

    Arguments are forwarded unchanged to ``query_microsoft_graph``.

    Raises:
        GraphQueryError: if the response has no ``value`` entry or that entry
            is not a list (in addition to errors from the underlying query).
    """
    payload = query_microsoft_graph(method, resource, params, body)
    items = payload.get("value")
    if not isinstance(items, list):
        raise GraphQueryError("Expected data containing a list of values", None)
    return items
def is_member_of(group_id: str, member_id: str) -> bool:
    """Return True if ``group_id`` is in the ``checkMemberGroups`` result.

    Posts ``group_id`` to the user's ``checkMemberGroups`` Graph resource and
    checks whether it appears in the returned list.
    """
    matched_groups = query_microsoft_graph_list(
        method="POST",
        resource=f"users/{member_id}/checkMemberGroups",
        body={"groupIds": [group_id]},
    )
    return group_id in matched_groups
@cached
def get_scaleset_identity_resource_path() -> str:
    """Return the resource path of the ``<instance>-scalesetid`` identity.

    The path is built from the subscription, the base resource group and the
    instance name.
    """
    identity_name = f"{get_instance_name()}-scalesetid"
    subscription = get_subscription()
    resource_group = get_base_resource_group()
    providers_path = (
        f"/subscriptions/{subscription}/resourceGroups/{resource_group}/providers"
    )
    return (
        f"{providers_path}/Microsoft.ManagedIdentity/userAssignedIdentities"
        f"/{identity_name}"
    )
@cached
def get_scaleset_principal_id() -> UUID:
    """Return the ``principalId`` of the scaleset identity resource as a UUID."""
    # matches the apiversion in the deployment template
    api_version = "2018-11-30"
    credential = get_identity()
    subscription_id = get_subscription()
    resource_client = ResourceManagementClient(
        credential=credential, subscription_id=subscription_id
    )
    identity_path = get_scaleset_identity_resource_path()
    identity = resource_client.resources.get_by_id(identity_path, api_version)
    principal_id = identity.properties["principalId"]
    return UUID(principal_id)
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
    """Return a memoized ``SecretClient`` for ``vault_url``."""
    # NOTE(review): this builds its own DefaultAzureCredential instead of
    # reusing get_identity(), and this function is not in the list cleared by
    # clear_azure_client_cache() - confirm that is intentional.
    credential = DefaultAzureCredential()
    return SecretClient(vault_url=vault_url, credential=credential)
def clear_azure_client_cache() -> None:
    """Clear the memoization of the Azure credentials and clients."""
    from .compute import get_compute_client
    from .containers import get_blob_service
    from .network_mgmt_client import get_network_client
    from .storage import get_mgmt_client

    memoized_factories = (
        get_msi,
        get_identity,
        get_compute_client,
        get_blob_service,
        get_network_client,
        get_mgmt_client,
    )

    # currently memoization.cache does not project the wrapped function's types.
    # As a workaround, CI comments out the `cached` wrapper, then runs the type
    # validation. Looking up clear_cache dynamically means it is only called
    # when the wrapper is actually in place.
    for factory in memoized_factories:
        clear_cache = getattr(factory, "clear_cache", None)
        if clear_cache is None:
            continue
        clear_cache()
T = TypeVar("T", bound=Callable[..., Any])


class retry_on_auth_failure:
    """Decorator that retries a call once after an authentication failure.

    If the wrapped function raises ``ClientAuthenticationError``, a warning is
    logged, the cached Azure clients are cleared, and the function is called
    one more time with the same arguments. An error from the second call
    propagates to the caller.
    """

    def __call__(self, func: T) -> T:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            try:
                result = func(*args, **kwargs)
            except ClientAuthenticationError as exc:
                logging.warning(
                    "clearing authentication cache after auth failure: %s", exc
                )
                clear_azure_client_cache()
                result = func(*args, **kwargs)
            return result

        return cast(T, wrapper)