assign scaleset to a role (#185)

This commit is contained in:
Cheick Keita
2020-10-28 12:13:31 -07:00
committed by GitHub
parent 59cfc52e9b
commit e76064b340
3 changed files with 207 additions and 119 deletions

View File

@ -12,18 +12,20 @@ import shutil
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import time
import uuid import uuid
import zipfile import zipfile
from datetime import datetime, timedelta from datetime import datetime, timedelta
from azure.cli.core import CLIError
from azure.common.client_factory import get_client_from_cli_profile from azure.common.client_factory import get_client_from_cli_profile
from azure.common.credentials import get_cli_profile from azure.common.credentials import get_cli_profile
from azure.core.exceptions import ResourceExistsError from azure.core.exceptions import ResourceExistsError
from azure.cosmosdb.table.tableservice import TableService from azure.cosmosdb.table.tableservice import TableService
from azure.graphrbac import GraphRbacManagementClient from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import ( from azure.graphrbac.models import (
Application,
ApplicationCreateParameters, ApplicationCreateParameters,
ApplicationUpdateParameters,
AppRole, AppRole,
GraphErrorException, GraphErrorException,
OptionalClaims, OptionalClaims,
@ -48,7 +50,6 @@ from azure.mgmt.resource.resources.models import (
DeploymentMode, DeploymentMode,
DeploymentProperties, DeploymentProperties,
) )
import time
from azure.mgmt.storage import StorageManagementClient from azure.mgmt.storage import StorageManagementClient
from azure.storage.blob import ( from azure.storage.blob import (
BlobServiceClient, BlobServiceClient,
@ -59,12 +60,12 @@ from azure.storage.queue import QueueServiceClient
from msrest.serialization import TZ_UTC from msrest.serialization import TZ_UTC
from data_migration import migrate from data_migration import migrate
from register_pool_application import ( from registration import (
add_application_password, add_application_password,
authorize_application, authorize_application,
get_application, get_application,
register_application, register_application,
update_registration, update_pool_registration,
) )
USER_IMPERSONATION = "311a71cc-e848-46a1-bdf8-97ff7156d8e6" USER_IMPERSONATION = "311a71cc-e848-46a1-bdf8-97ff7156d8e6"
@ -225,12 +226,11 @@ class Client:
while True: while True:
time.sleep(wait) time.sleep(wait)
count += 1 count += 1
try: password = add_application_password(object_id)
return add_application_password(object_id) if password:
except CLIError as err: return password
if count > timeout_seconds/wait: if count > timeout_seconds/wait:
raise err raise Exception("creating password failed, trying again")
logger.info("creating password failed, trying again")
def setup_rbac(self): def setup_rbac(self):
""" """
@ -256,6 +256,25 @@ class Client:
logger.error("unable to query RBAC. Provide client_id and client_secret") logger.error("unable to query RBAC. Provide client_id and client_secret")
sys.exit(1) sys.exit(1)
app_roles = [
AppRole(
allowed_member_types=["Application"],
display_name="CliClient",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allows access from the CLI.",
value="CliClient",
),
AppRole(
allowed_member_types=["Application"],
display_name="ManagedNode",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allow access from a lab machine.",
value="ManagedNode",
),
]
if not existing: if not existing:
logger.info("creating Application registration") logger.info("creating Application registration")
url = "https://%s.azurewebsites.net" % self.application_name url = "https://%s.azurewebsites.net" % self.application_name
@ -273,24 +292,7 @@ class Client:
resource_app_id="00000002-0000-0000-c000-000000000000", resource_app_id="00000002-0000-0000-c000-000000000000",
) )
], ],
app_roles=[ app_roles=app_roles,
AppRole(
allowed_member_types=["Application"],
display_name="CliClient",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allows access from the CLI.",
value="CliClient",
),
AppRole(
allowed_member_types=["Application"],
display_name="LabMachine",
id=str(uuid.uuid4()),
is_enabled=True,
description="Allow access from a lab machine.",
value="LabMachine",
),
],
) )
app = client.applications.create(params) app = client.applications.create(params)
@ -303,7 +305,27 @@ class Client:
) )
client.service_principals.create(service_principal_params) client.service_principals.create(service_principal_params)
else: else:
app = existing[0] app: Application = existing[0]
existing_role_values = [app_role.value for app_role in app.app_roles]
has_missing_roles = any(
[role.value not in existing_role_values for role in app_roles]
)
if has_missing_roles:
# disabling the existing app role first to allow the update
# this is a requirement to update the application roles
for role in app.app_roles:
role.is_enabled = False
client.applications.patch(
app.object_id, ApplicationUpdateParameters(app_roles=app.app_roles)
)
# overriding the list of app roles
client.applications.patch(
app.object_id, ApplicationUpdateParameters(app_roles=app_roles)
)
creds = list(client.applications.list_password_credentials(app.object_id)) creds = list(client.applications.list_password_credentials(app.object_id))
client.applications.update_password_credentials(app.object_id, creds) client.applications.update_password_credentials(app.object_id, creds)
@ -612,7 +634,7 @@ class Client:
def update_registration(self): def update_registration(self):
if not self.create_registration: if not self.create_registration:
return return
update_registration(self.application_name) update_pool_registration(self.application_name)
def done(self): def done(self):
logger.info(TELEMETRY_NOTICE) logger.info(TELEMETRY_NOTICE)
@ -766,19 +788,6 @@ def main():
logging.getLogger("deploy").setLevel(logging.INFO) logging.getLogger("deploy").setLevel(logging.INFO)
# TODO: using az_cli resets logging defaults. For now, force these
# to be WARN level
if not args.verbose:
for entry in [
"adal-python",
"msrest.universal_http",
"urllib3.connectionpool",
"az_command_data_logger",
"msrest.service_client",
"azure.core.pipeline.policies.http_logging_policy",
]:
logging.getLogger(entry).setLevel(logging.WARN)
if args.start_at != states[0][0]: if args.start_at != states[0][0]:
logger.warning( logger.warning(
"*** Starting at a non-standard deployment state. " "*** Starting at a non-standard deployment state. "

View File

@ -4,15 +4,15 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import argparse import argparse
import json
import logging import logging
import os import urllib.parse
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, List, NamedTuple, Optional, Tuple from typing import Dict, List, NamedTuple, Optional, Tuple
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from azure.cli.core import get_default_cli # type: ignore import requests
from azure.common.client_factory import get_client_from_cli_profile from azure.common.client_factory import get_client_from_cli_profile
from azure.common.credentials import get_cli_profile
from azure.graphrbac import GraphRbacManagementClient from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import ( from azure.graphrbac.models import (
Application, Application,
@ -26,13 +26,41 @@ from msrest.serialization import TZ_UTC
logger = logging.getLogger("deploy") logger = logging.getLogger("deploy")
def az_cli(args): class GraphQueryError(Exception):
cli = get_default_cli() pass
cli.invoke(args, out_file=open(os.devnull, "w"))
if cli.result.result:
return cli.result.result def query_microsoft_graph(
elif cli.result.error: method: str,
raise cli.result.error resource: str,
params: Optional[Dict] = None,
body: Optional[Dict] = None,
):
profile = get_cli_profile()
(token_type, access_token, _), _, _ = profile.get_raw_token(
resource="https://graph.microsoft.com"
)
url = urllib.parse.urljoin("https://graph.microsoft.com/v1.0/", resource)
headers = {
"Authorization": "%s %s" % (token_type, access_token),
"Content-Type": "application/json",
}
response = requests.request(
method=method, url=url, headers=headers, params=params, json=body
)
response.status_code
if 200 <= response.status_code < 300:
try:
return response.json()
except ValueError:
return None
else:
error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
raise GraphQueryError(
"request did not succeed: HTTP %s - %s" % (response.status_code, error_text)
)
class ApplicationInfo(NamedTuple): class ApplicationInfo(NamedTuple):
@ -129,32 +157,22 @@ def create_application_registration(
) )
registered_app: Application = client.applications.create(params) registered_app: Application = client.applications.create(params)
body = {
query_microsoft_graph(
method="PATCH",
resource="applications/%s" % registered_app.object_id,
body={
"publicClient": { "publicClient": {
"redirectUris": ["https://%s.azurewebsites.net" % onefuzz_instance_name] "redirectUris": ["https://%s.azurewebsites.net" % onefuzz_instance_name]
}, },
"isFallbackPublicClient": True, "isFallbackPublicClient": True,
} },
az_cli(
[
"rest",
"-m",
"PATCH",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s"
% registered_app.object_id,
"--headers",
"Content-Type=application/json",
"-b",
json.dumps(body),
]
) )
authorize_application(UUID(registered_app.app_id), UUID(app.app_id)) authorize_application(UUID(registered_app.app_id), UUID(app.app_id))
return registered_app return registered_app
def add_application_password(app_object_id: UUID) -> Tuple[str, str]: def add_application_password(app_object_id: UUID) -> Optional[Tuple[str, str]]:
key = uuid4() key = uuid4()
password_request = { password_request = {
"passwordCredential": { "passwordCredential": {
@ -166,36 +184,25 @@ def add_application_password(app_object_id: UUID) -> Tuple[str, str]:
), ),
} }
} }
try:
password: Dict = az_cli( password: Dict = query_microsoft_graph(
[ method="POST",
"rest", resource="applications/%s/addPassword" % app_object_id,
"-m", body=password_request,
"POST",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s/addPassword"
% app_object_id,
"-b",
json.dumps(password_request),
]
) )
return (str(key), password["secretText"]) return (str(key), password["secretText"])
except GraphQueryError as err:
logger.warning("creating password failed : %s" % err)
None
def get_application(app_id: UUID) -> Optional[Dict]: def get_application(app_id: UUID) -> Optional[Dict]:
apps: Dict = az_cli( apps: Dict = query_microsoft_graph(
[ method="GET",
"rest", resource="applications",
"-m", params={"$filter": "appId eq '%s'" % app_id},
"GET",
"-u",
"https://graph.microsoft.com/v1.0/applications",
"--uri-parameters",
"$filter=appId eq '%s'" % app_id,
]
) )
if len(apps["value"]) == 0: if len(apps["value"]) == 0:
return None return None
@ -234,24 +241,16 @@ def authorize_application(
.map(lambda data: {"appId": data[0], "delegatedPermissionIds": data[1]}) .map(lambda data: {"appId": data[0], "delegatedPermissionIds": data[1]})
) )
body = {"api": {"preAuthorizedApplications": preAuthorizedApplications.to_list()}} query_microsoft_graph(
method="PATCH",
az_cli( resource="applications/%s" % onefuzz_app["id"],
[ body={
"rest", "api": {"preAuthorizedApplications": preAuthorizedApplications.to_list()}
"-m", },
"PATCH",
"-u",
"https://graph.microsoft.com/v1.0/applications/%s" % onefuzz_app["id"],
"--headers",
"Content-Type=application/json",
"-b",
json.dumps(body),
]
) )
def update_registration(application_name: str): def update_pool_registration(application_name: str):
logger.info("Updating application registration") logger.info("Updating application registration")
application_info = register_application( application_info = register_application(
@ -264,8 +263,72 @@ def update_registration(application_name: str):
logger.info("client_secret: %s" % application_info.client_secret) logger.info("client_secret: %s" % application_info.client_secret)
def assign_scaleset_role(onefuzz_instance_name: str, scaleset_name: str):
""" Allows the nodes in the scaleset to access the service by assigning their managed identity to the ManagedNode Role """
onefuzz_service_appId = query_microsoft_graph(
method="GET",
resource="applications",
params={
"$filter": "displayName eq '%s'" % onefuzz_instance_name,
"$select": "appId",
},
)
if len(onefuzz_service_appId["value"]) == 0:
raise Exception("onefuzz app registration not found")
appId = onefuzz_service_appId["value"][0]["appId"]
onefuzz_service_principals = query_microsoft_graph(
method="GET",
resource="servicePrincipals",
params={"$filter": "appId eq '%s'" % appId},
)
if len(onefuzz_service_principals["value"]) == 0:
raise Exception("onefuzz app service principal not found")
onefuzz_service_principal = onefuzz_service_principals["value"][0]
scaleset_service_principals = query_microsoft_graph(
method="GET",
resource="servicePrincipals",
params={"$filter": "displayName eq '%s'" % scaleset_name},
)
if len(scaleset_service_principals["value"]) == 0:
raise Exception("scaleset service principal not found")
scaleset_service_principal = scaleset_service_principals["value"][0]
lab_machine_role = (
seq(onefuzz_service_principal["appRoles"])
.filter(lambda x: x["value"] == "ManagedNode")
.head_option()
)
if not lab_machine_role:
raise Exception(
"ManagedNode role not found int the onefuzz application registration. Please redeploy the instance"
)
query_microsoft_graph(
method="POST",
resource="servicePrincipals/%s/appRoleAssignedTo"
% scaleset_service_principal["id"],
body={
"principalId": scaleset_service_principal["id"],
"resourceId": onefuzz_service_principal["id"],
"appRoleId": lab_machine_role["id"],
},
)
def main(): def main():
formatter = argparse.ArgumentDefaultsHelpFormatter formatter = argparse.ArgumentDefaultsHelpFormatter
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
"onefuzz_instance", help="the name of the onefuzz instance"
)
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=formatter, formatter_class=formatter,
description=( description=(
@ -273,8 +336,19 @@ def main():
"generate a password for the pool agent" "generate a password for the pool agent"
), ),
) )
parser.add_argument("application_name")
parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument("-v", "--verbose", action="store_true")
subparsers = parser.add_subparsers(title="commands", dest="command")
subparsers.add_parser("update_pool_registration", parents=[parent_parser])
role_assignment_parser = subparsers.add_parser(
"assign_scaleset_role",
parents=[parent_parser],
)
role_assignment_parser.add_argument(
"scaleset_name",
help="the name of the scaleset",
)
args = parser.parse_args() args = parser.parse_args()
if args.verbose: if args.verbose:
level = logging.DEBUG level = logging.DEBUG
@ -284,7 +358,12 @@ def main():
logging.basicConfig(format="%(levelname)s:%(message)s", level=level) logging.basicConfig(format="%(levelname)s:%(message)s", level=level)
logging.getLogger("deploy").setLevel(logging.INFO) logging.getLogger("deploy").setLevel(logging.INFO)
update_registration(args.application_name) if args.command == "update_pool_registration":
update_pool_registration(args.onefuzz_instance)
elif args.command == "assign_scaleset_role":
assign_scaleset_role(args.onefuzz_instance, args.scaleset_name)
else:
raise Exception("invalid arguments")
if __name__ == "__main__": if __name__ == "__main__":