assign scaleset to a role (#185)

This commit is contained in:
Cheick Keita
2020-10-28 12:13:31 -07:00
committed by GitHub
parent 59cfc52e9b
commit e76064b340
3 changed files with 207 additions and 119 deletions

View File

@ -12,18 +12,20 @@ import shutil
import subprocess
import sys
import tempfile
import time
import uuid
import zipfile
from datetime import datetime, timedelta
from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile
from azure.common.credentials import get_cli_profile
from azure.core.exceptions import ResourceExistsError
from azure.cosmosdb.table.tableservice import TableService
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import (
Application,
ApplicationCreateParameters,
ApplicationUpdateParameters,
AppRole,
GraphErrorException,
OptionalClaims,
@ -48,7 +50,6 @@ from azure.mgmt.resource.resources.models import (
DeploymentMode,
DeploymentProperties,
)
import time
from azure.mgmt.storage import StorageManagementClient
from azure.storage.blob import (
BlobServiceClient,
@ -59,12 +60,12 @@ from azure.storage.queue import QueueServiceClient
from msrest.serialization import TZ_UTC
from data_migration import migrate
from register_pool_application import (
from registration import (
add_application_password,
authorize_application,
get_application,
register_application,
update_registration,
update_pool_registration,
)
USER_IMPERSONATION = "311a71cc-e848-46a1-bdf8-97ff7156d8e6"
@ -225,12 +226,11 @@ class Client:
while True:
time.sleep(wait)
count += 1
try:
return add_application_password(object_id)
except CLIError as err:
password = add_application_password(object_id)
if password:
return password
if count > timeout_seconds/wait:
raise err
logger.info("creating password failed, trying again")
raise Exception("creating password failed, trying again")
def setup_rbac(self):
"""
@ -256,6 +256,25 @@ class Client:
logger.error("unable to query RBAC. Provide client_id and client_secret")
sys.exit(1)
app_roles = [
AppRole(
allowed_member_types=["Application"],
display_name="CliClient",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allows access from the CLI.",
value="CliClient",
),
AppRole(
allowed_member_types=["Application"],
display_name="ManagedNode",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allow access from a lab machine.",
value="ManagedNode",
),
]
if not existing:
logger.info("creating Application registration")
url = "https://%s.azurewebsites.net" % self.application_name
@ -273,24 +292,7 @@ class Client:
resource_app_id="00000002-0000-0000-c000-000000000000",
)
],
app_roles=[
AppRole(
allowed_member_types=["Application"],
display_name="CliClient",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allows access from the CLI.",
value="CliClient",
),
AppRole(
allowed_member_types=["Application"],
display_name="LabMachine",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allow access from a lab machine.",
value="LabMachine",
),
],
app_roles=app_roles,
)
app = client.applications.create(params)
@ -303,7 +305,27 @@ class Client:
)
client.service_principals.create(service_principal_params)
else:
app = existing[0]
app: Application = existing[0]
existing_role_values = [app_role.value for app_role in app.app_roles]
has_missing_roles = any(
[role.value not in existing_role_values for role in app_roles]
)
if has_missing_roles:
# disabling the existing app role first to allow the update
# this is a requirement to update the application roles
for role in app.app_roles:
role.is_enabled = False
client.applications.patch(
app.object_id, ApplicationUpdateParameters(app_roles=app.app_roles)
)
# overriding the list of app roles
client.applications.patch(
app.object_id, ApplicationUpdateParameters(app_roles=app_roles)
)
creds = list(client.applications.list_password_credentials(app.object_id))
client.applications.update_password_credentials(app.object_id, creds)
@ -612,7 +634,7 @@ class Client:
def update_registration(self):
if not self.create_registration:
return
update_registration(self.application_name)
update_pool_registration(self.application_name)
def done(self):
logger.info(TELEMETRY_NOTICE)
@ -766,19 +788,6 @@ def main():
logging.getLogger("deploy").setLevel(logging.INFO)
# TODO: using az_cli resets logging defaults. For now, force these
# to be WARN level
if not args.verbose:
for entry in [
"adal-python",
"msrest.universal_http",
"urllib3.connectionpool",
"az_command_data_logger",
"msrest.service_client",
"azure.core.pipeline.policies.http_logging_policy",
]:
logging.getLogger(entry).setLevel(logging.WARN)
if args.start_at != states[0][0]:
logger.warning(
"*** Starting at a non-standard deployment state. "

View File

@ -4,15 +4,15 @@
# Licensed under the MIT License.
import argparse
import json
import logging
import os
import urllib.parse
from datetime import datetime, timedelta
from typing import Dict, List, NamedTuple, Optional, Tuple
from uuid import UUID, uuid4
from azure.cli.core import get_default_cli # type: ignore
import requests
from azure.common.client_factory import get_client_from_cli_profile
from azure.common.credentials import get_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import (
Application,
@ -26,13 +26,41 @@ from msrest.serialization import TZ_UTC
logger = logging.getLogger("deploy")
def az_cli(args):
cli = get_default_cli()
cli.invoke(args, out_file=open(os.devnull, "w"))
if cli.result.result:
return cli.result.result
elif cli.result.error:
raise cli.result.error
class GraphQueryError(Exception):
pass
def query_microsoft_graph(
method: str,
resource: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
):
profile = get_cli_profile()
(token_type, access_token, _), _, _ = profile.get_raw_token(
resource="https://graph.microsoft.com"
)
url = urllib.parse.urljoin("https://graph.microsoft.com/v1.0/", resource)
headers = {
"Authorization": "%s %s" % (token_type, access_token),
"Content-Type": "application/json",
}
response = requests.request(
method=method, url=url, headers=headers, params=params, json=body
)
response.status_code
if 200 <= response.status_code < 300:
try:
return response.json()
except ValueError:
return None
else:
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
raise GraphQueryError(
"request did not succeed: HTTP %s - %s" % (response.status_code, error_text)
)
class ApplicationInfo(NamedTuple):
@ -129,32 +157,22 @@ def create_application_registration(
)
registered_app: Application = client.applications.create(params)
body = {
query_microsoft_graph(
method="PATCH",
resource="applications/%s" % registered_app.object_id,
body={
"publicClient": {
"redirectUris": ["https://%s.azurewebsites.net" % onefuzz_instance_name]
},
"isFallbackPublicClient": True,
}
az_cli(
[
"rest",
"-m",
"PATCH",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s"
% registered_app.object_id,
"--headers",
"Content-Type=application/json",
"-b",
json.dumps(body),
]
},
)
authorize_application(UUID(registered_app.app_id), UUID(app.app_id))
return registered_app
def add_application_password(app_object_id: UUID) -> Tuple[str, str]:
def add_application_password(app_object_id: UUID) -> Optional[Tuple[str, str]]:
key = uuid4()
password_request = {
"passwordCredential": {
@ -166,36 +184,25 @@ def add_application_password(app_object_id: UUID) -> Tuple[str, str]:
),
}
}
password: Dict = az_cli(
[
"rest",
"-m",
"POST",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s/addPassword"
% app_object_id,
"-b",
json.dumps(password_request),
]
try:
password: Dict = query_microsoft_graph(
method="POST",
resource="applications/%s/addPassword" % app_object_id,
body=password_request,
)
return (str(key), password["secretText"])
except GraphQueryError as err:
logger.warning("creating password failed : %s" % err)
None
def get_application(app_id: UUID) -> Optional[Dict]:
apps: Dict = az_cli(
[
"rest",
"-m",
"GET",
"-u",
"https://graph.microsoft.com/v1.0/applications",
"--uri-parameters",
"$filter=appId eq '%s'" % app_id,
]
apps: Dict = query_microsoft_graph(
method="GET",
resource="applications",
params={"$filter": "appId eq '%s'" % app_id},
)
if len(apps["value"]) == 0:
return None
@ -234,24 +241,16 @@ def authorize_application(
.map(lambda data: {"appId": data[0], "delegatedPermissionIds": data[1]})
)
body = {"api": {"preAuthorizedApplications": preAuthorizedApplications.to_list()}}
az_cli(
[
"rest",
"-m",
"PATCH",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s" % onefuzz_app["id"],
"--headers",
"Content-Type=application/json",
"-b",
json.dumps(body),
]
query_microsoft_graph(
method="PATCH",
resource="applications/%s" % onefuzz_app["id"],
body={
"api": {"preAuthorizedApplications": preAuthorizedApplications.to_list()}
},
)
def update_registration(application_name: str):
def update_pool_registration(application_name: str):
logger.info("Updating application registration")
application_info = register_application(
@ -264,8 +263,72 @@ def update_registration(application_name: str):
logger.info("client_secret: %s" % application_info.client_secret)
def assign_scaleset_role(onefuzz_instance_name: str, scaleset_name: str):
""" Allows the nodes in the scaleset to access the service by assigning their managed identity to the ManagedNode Role """
onefuzz_service_appId = query_microsoft_graph(
method="GET",
resource="applications",
params={
"$filter": "displayName eq '%s'" % onefuzz_instance_name,
"$select": "appId",
},
)
if len(onefuzz_service_appId["value"]) == 0:
raise Exception("onefuzz app registration not found")
appId = onefuzz_service_appId["value"][0]["appId"]
onefuzz_service_principals = query_microsoft_graph(
method="GET",
resource="servicePrincipals",
params={"$filter": "appId eq '%s'" % appId},
)
if len(onefuzz_service_principals["value"]) == 0:
raise Exception("onefuzz app service principal not found")
onefuzz_service_principal = onefuzz_service_principals["value"][0]
scaleset_service_principals = query_microsoft_graph(
method="GET",
resource="servicePrincipals",
params={"$filter": "displayName eq '%s'" % scaleset_name},
)
if len(scaleset_service_principals["value"]) == 0:
raise Exception("scaleset service principal not found")
scaleset_service_principal = scaleset_service_principals["value"][0]
lab_machine_role = (
seq(onefuzz_service_principal["appRoles"])
.filter(lambda x: x["value"] == "ManagedNode")
.head_option()
)
if not lab_machine_role:
raise Exception(
"ManagedNode role not found int the onefuzz application registration. Please redeploy the instance"
)
query_microsoft_graph(
method="POST",
resource="servicePrincipals/%s/appRoleAssignedTo"
% scaleset_service_principal["id"],
body={
"principalId": scaleset_service_principal["id"],
"resourceId": onefuzz_service_principal["id"],
"appRoleId": lab_machine_role["id"],
},
)
def main():
formatter = argparse.ArgumentDefaultsHelpFormatter
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
"onefuzz_instance", help="the name of the onefuzz instance"
)
parser = argparse.ArgumentParser(
formatter_class=formatter,
description=(
@ -273,8 +336,19 @@ def main():
"generate a password for the pool agent"
),
)
parser.add_argument("application_name")
parser.add_argument("-v", "--verbose", action="store_true")
subparsers = parser.add_subparsers(title="commands", dest="command")
subparsers.add_parser("update_pool_registration", parents=[parent_parser])
role_assignment_parser = subparsers.add_parser(
"assign_scaleset_role",
parents=[parent_parser],
)
role_assignment_parser.add_argument(
"scaleset_name",
help="the name of the scaleset",
)
args = parser.parse_args()
if args.verbose:
level = logging.DEBUG
@ -284,7 +358,12 @@ def main():
logging.basicConfig(format="%(levelname)s:%(message)s", level=level)
logging.getLogger("deploy").setLevel(logging.INFO)
update_registration(args.application_name)
if args.command == "update_pool_registration":
update_pool_registration(args.onefuzz_instance)
elif args.command == "assign_scaleset_role":
assign_scaleset_role(args.onefuzz_instance, args.scaleset_name)
else:
raise Exception("invalid arguments")
if __name__ == "__main__":