Merge pull request from GHSA-q5vh-6whw-x745

* verify aad tenants, primarily needed in multi-tenant deployments

* add logging and fix trailing slash for issuer

* handle call_if* not supporting additional argument callbacks

* add logging

* include new datatype in webhook docs

* fix pytypes unit tests

Co-authored-by: Brian Caswell <bmc@shmoo.com>
This commit is contained in:
bmc-msft
2021-08-13 14:50:54 -04:00
committed by GitHub
parent ba3a6eab04
commit 2fcb499888
12 changed files with 193 additions and 31 deletions

View File

@ -641,7 +641,10 @@ Each event will be submitted via HTTP POST to the user provided URL.
"admins": [ "admins": [
"00000000-0000-0000-0000-000000000000" "00000000-0000-0000-0000-000000000000"
], ],
"allow_pool_management": true "allow_pool_management": true,
"allowed_aad_tenants": [
"00000000-0000-0000-0000-000000000000"
]
} }
} }
``` ```
@ -665,8 +668,19 @@ Each event will be submitted via HTTP POST to the user provided URL.
"default": true, "default": true,
"title": "Allow Pool Management", "title": "Allow Pool Management",
"type": "boolean" "type": "boolean"
},
"allowed_aad_tenants": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Aad Tenants",
"type": "array"
} }
}, },
"required": [
"allowed_aad_tenants"
],
"title": "InstanceConfig", "title": "InstanceConfig",
"type": "object" "type": "object"
} }
@ -5599,8 +5613,19 @@ Each event will be submitted via HTTP POST to the user provided URL.
"default": true, "default": true,
"title": "Allow Pool Management", "title": "Allow Pool Management",
"type": "boolean" "type": "boolean"
},
"allowed_aad_tenants": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Aad Tenants",
"type": "array"
} }
}, },
"required": [
"allowed_aad_tenants"
],
"title": "InstanceConfig", "title": "InstanceConfig",
"type": "object" "type": "object"
}, },

View File

@ -14,11 +14,12 @@ from ..onefuzzlib.azure.creds import (
get_instance_id, get_instance_id,
get_subscription, get_subscription,
) )
from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import ok from ..onefuzzlib.request import ok
from ..onefuzzlib.versions import versions from ..onefuzzlib.versions import versions
def main(req: func.HttpRequest) -> func.HttpResponse: def get(req: func.HttpRequest) -> func.HttpResponse:
response = ok( response = ok(
Info( Info(
resource_group=get_base_resource_group(), resource_group=get_base_resource_group(),
@ -32,3 +33,11 @@ def main(req: func.HttpRequest) -> func.HttpResponse:
) )
return response return response
def main(req: func.HttpRequest) -> func.HttpResponse:
methods = {"GET": get}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@ -5,6 +5,8 @@
import azure.functions as func import azure.functions as func
from ..onefuzzlib.endpoint_authorization import call_if_user
# This endpoint handles the signalr negotation # This endpoint handles the signalr negotation
# As we do not differentiate from clients at this time, we pass the Functions runtime # As we do not differentiate from clients at this time, we pass the Functions runtime
# provided connection straight to the client # provided connection straight to the client
@ -14,8 +16,19 @@ import azure.functions as func
def main(req: func.HttpRequest, connectionInfoJson: str) -> func.HttpResponse: def main(req: func.HttpRequest, connectionInfoJson: str) -> func.HttpResponse:
# NOTE: this is a sub-method because the call_if* do not support callbacks with
# additional arguments at this time. Once call_if* supports additional arguments,
# this should be made a generic function
def post(req: func.HttpRequest) -> func.HttpResponse:
return func.HttpResponse( return func.HttpResponse(
connectionInfoJson, connectionInfoJson,
status_code=200, status_code=200,
headers={"Content-type": "application/json"}, headers={"Content-type": "application/json"},
) )
methods = {"POST": post}
method = methods[req.method]
result = call_if_user(req, method)
return result

View File

@ -25,7 +25,7 @@ class InstanceConfig(BASE_CONFIG, ORMMixin):
def fetch(cls) -> "InstanceConfig": def fetch(cls) -> "InstanceConfig":
entry = cls.get(get_instance_name()) entry = cls.get(get_instance_name())
if entry is None: if entry is None:
entry = cls() entry = cls(allowed_aad_tenants=[])
entry.save() entry.save()
return entry return entry

View File

@ -3,14 +3,18 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
from typing import Optional import logging
from typing import List, Optional
from uuid import UUID from uuid import UUID
import azure.functions as func import azure.functions as func
import jwt import jwt
from memoization import cached
from onefuzztypes.enums import ErrorCode from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Result, UserInfo from onefuzztypes.models import Error, Result, UserInfo
from .config import InstanceConfig
def get_bearer_token(request: func.HttpRequest) -> Optional[str]: def get_bearer_token(request: func.HttpRequest) -> Optional[str]:
auth: str = request.headers.get("Authorization", None) auth: str = request.headers.get("Authorization", None)
@ -39,6 +43,13 @@ def get_auth_token(request: func.HttpRequest) -> Optional[str]:
return str(token_header) return str(token_header)
@cached(ttl=60)
def get_allowed_tenants() -> List[str]:
config = InstanceConfig.fetch()
entries = [f"https://sts.windows.net/{x}/" for x in config.allowed_aad_tenants]
return entries
def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]: def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
"""Obtains the Access Token from the Authorization Header""" """Obtains the Access Token from the Authorization Header"""
token_str = get_auth_token(request) token_str = get_auth_token(request)
@ -48,9 +59,20 @@ def parse_jwt_token(request: func.HttpRequest) -> Result[UserInfo]:
errors=["unable to find authorization token"], errors=["unable to find authorization token"],
) )
# This token has already been verified by the azure authentication layer # The JWT token has already been verified by the azure authentication layer,
# but we need to verify the tenant is as we expect.
token = jwt.decode(token_str, options={"verify_signature": False}) token = jwt.decode(token_str, options={"verify_signature": False})
if "iss" not in token:
return Error(
code=ErrorCode.INVALID_REQUEST, errors=["missing issuer from token"]
)
tenants = get_allowed_tenants()
if token["iss"] not in tenants:
logging.error("issuer not from allowed tenant: %s - %s", token["iss"], tenants)
return Error(code=ErrorCode.INVALID_REQUEST, errors=["unauthorized AAD issuer"])
application_id = UUID(token["appid"]) if "appid" in token else None application_id = UUID(token["appid"]) if "appid" in token else None
object_id = UUID(token["oid"]) if "oid" in token else None object_id = UUID(token["oid"]) if "oid" in token else None
upn = token.get("upn") upn = token.get("upn")

View File

@ -5,7 +5,7 @@
import os import os
import unittest import unittest
from uuid import uuid4 from uuid import UUID, uuid4
from onefuzztypes.models import UserInfo from onefuzztypes.models import UserInfo
@ -25,29 +25,41 @@ class TestAdmin(unittest.TestCase):
user2 = uuid4() user2 = uuid4()
# no admins set # no admins set
self.assertTrue(can_modify_config_impl(InstanceConfig(), UserInfo())) self.assertTrue(
can_modify_config_impl(
InstanceConfig(allowed_aad_tenants=[UUID(int=0)]), UserInfo()
)
)
# with oid, but no admin # with oid, but no admin
self.assertTrue( self.assertTrue(
can_modify_config_impl(InstanceConfig(), UserInfo(object_id=user1)) can_modify_config_impl(
InstanceConfig(allowed_aad_tenants=[UUID(int=0)]),
UserInfo(object_id=user1),
)
) )
# is admin # is admin
self.assertTrue( self.assertTrue(
can_modify_config_impl( can_modify_config_impl(
InstanceConfig(admins=[user1]), UserInfo(object_id=user1) InstanceConfig(allowed_aad_tenants=[UUID(int=0)], admins=[user1]),
UserInfo(object_id=user1),
) )
) )
# no user oid set # no user oid set
self.assertFalse( self.assertFalse(
can_modify_config_impl(InstanceConfig(admins=[user1]), UserInfo()) can_modify_config_impl(
InstanceConfig(allowed_aad_tenants=[UUID(int=0)], admins=[user1]),
UserInfo(),
)
) )
# not an admin # not an admin
self.assertFalse( self.assertFalse(
can_modify_config_impl( can_modify_config_impl(
InstanceConfig(admins=[user1]), UserInfo(object_id=user2) InstanceConfig(allowed_aad_tenants=[UUID(int=0)], admins=[user1]),
UserInfo(object_id=user2),
) )
) )
@ -58,21 +70,31 @@ class TestAdmin(unittest.TestCase):
# by default, any can modify # by default, any can modify
self.assertIsNone( self.assertIsNone(
check_can_manage_pools_impl( check_can_manage_pools_impl(
InstanceConfig(allow_pool_management=True), UserInfo() InstanceConfig(
allowed_aad_tenants=[UUID(int=0)], allow_pool_management=True
),
UserInfo(),
) )
) )
# with oid, but no admin # with oid, but no admin
self.assertIsNone( self.assertIsNone(
check_can_manage_pools_impl( check_can_manage_pools_impl(
InstanceConfig(allow_pool_management=True), UserInfo(object_id=user1) InstanceConfig(
allowed_aad_tenants=[UUID(int=0)], allow_pool_management=True
),
UserInfo(object_id=user1),
) )
) )
# is admin # is admin
self.assertIsNone( self.assertIsNone(
check_can_manage_pools_impl( check_can_manage_pools_impl(
InstanceConfig(allow_pool_management=False, admins=[user1]), InstanceConfig(
allowed_aad_tenants=[UUID(int=0)],
allow_pool_management=False,
admins=[user1],
),
UserInfo(object_id=user1), UserInfo(object_id=user1),
) )
) )
@ -80,14 +102,23 @@ class TestAdmin(unittest.TestCase):
# no user oid set # no user oid set
self.assertIsNotNone( self.assertIsNotNone(
check_can_manage_pools_impl( check_can_manage_pools_impl(
InstanceConfig(allow_pool_management=False, admins=[user1]), UserInfo() InstanceConfig(
allowed_aad_tenants=[UUID(int=0)],
allow_pool_management=False,
admins=[user1],
),
UserInfo(),
) )
) )
# not an admin # not an admin
self.assertIsNotNone( self.assertIsNotNone(
check_can_manage_pools_impl( check_can_manage_pools_impl(
InstanceConfig(allow_pool_management=False, admins=[user1]), InstanceConfig(
allowed_aad_tenants=[UUID(int=0)],
allow_pool_management=False,
admins=[user1],
),
UserInfo(object_id=user2), UserInfo(object_id=user2),
) )
) )

View File

@ -821,6 +821,10 @@
"scaleset-identity": { "scaleset-identity": {
"type": "string", "type": "string",
"value": "[variables('scaleset_identity')]" "value": "[variables('scaleset_identity')]"
},
"tenant_id": {
"type": "string",
"value": "[subscription().tenantId]"
} }
} }
} }

View File

@ -72,7 +72,7 @@ from registration import (
set_app_audience, set_app_audience,
update_pool_registration, update_pool_registration,
) )
from set_admins import update_admins from set_admins import update_admins, update_allowed_aad_tenants
# Found by manually assigning the User.Read permission to application # Found by manually assigning the User.Read permission to application
# registration in the admin portal. The values are in the manifest under # registration in the admin portal. The values are in the manifest under
@ -130,7 +130,8 @@ class Client:
multi_tenant_domain: str, multi_tenant_domain: str,
upgrade: bool, upgrade: bool,
subscription_id: Optional[str], subscription_id: Optional[str],
admins: List[UUID] admins: List[UUID],
allowed_aad_tenants: List[UUID],
): ):
self.subscription_id = subscription_id self.subscription_id = subscription_id
self.resource_group = resource_group self.resource_group = resource_group
@ -161,6 +162,7 @@ class Client:
self.export_appinsights = export_appinsights self.export_appinsights = export_appinsights
self.log_service_principal = log_service_principal self.log_service_principal = log_service_principal
self.admins = admins self.admins = admins
self.allowed_aad_tenants = allowed_aad_tenants
machine = platform.machine() machine = platform.machine()
system = platform.system() system = platform.system()
@ -560,13 +562,20 @@ class Client:
table_service = TableService(account_name=name, account_key=key) table_service = TableService(account_name=name, account_key=key)
migrate(table_service, self.migrations) migrate(table_service, self.migrations)
def set_admins(self) -> None: def set_instance_config(self) -> None:
name = self.results["deploy"]["func-name"]["value"] name = self.results["deploy"]["func-name"]["value"]
key = self.results["deploy"]["func-key"]["value"] key = self.results["deploy"]["func-key"]["value"]
tenant = UUID(self.results["deploy"]["tenant_id"]["value"])
table_service = TableService(account_name=name, account_key=key) table_service = TableService(account_name=name, account_key=key)
if self.admins: if self.admins:
update_admins(table_service, self.application_name, self.admins) update_admins(table_service, self.application_name, self.admins)
tenants = self.allowed_aad_tenants
if tenant not in tenants:
tenants.append(tenant)
update_allowed_aad_tenants(table_service, self.application_name, tenants)
def create_queues(self) -> None: def create_queues(self) -> None:
logger.info("creating eventgrid destination queue") logger.info("creating eventgrid destination queue")
@ -926,7 +935,7 @@ def main() -> None:
full_deployment_states = rbac_only_states + [ full_deployment_states = rbac_only_states + [
("apply_migrations", Client.apply_migrations), ("apply_migrations", Client.apply_migrations),
("set_admins", Client.set_admins), ("set_instance_config", Client.set_instance_config),
("queues", Client.create_queues), ("queues", Client.create_queues),
("eventgrid", Client.create_eventgrid), ("eventgrid", Client.create_eventgrid),
("tools", Client.upload_tools), ("tools", Client.upload_tools),
@ -1038,6 +1047,12 @@ def main() -> None:
nargs="*", nargs="*",
help="set the list of administrators (by OID in AAD)", help="set the list of administrators (by OID in AAD)",
) )
parser.add_argument(
"--allowed_aad_tenants",
type=UUID,
nargs="*",
help="Set additional AAD tenants beyond the tenant the app is deployed in",
)
args = parser.parse_args() args = parser.parse_args()
@ -1066,6 +1081,7 @@ def main() -> None:
upgrade=args.upgrade, upgrade=args.upgrade,
subscription_id=args.subscription_id, subscription_id=args.subscription_id,
admins=args.set_admins, admins=args.set_admins,
allowed_aad_tenants=args.allowed_aad_tenants or [],
) )
if args.verbose: if args.verbose:
level = logging.DEBUG level = logging.DEBUG

View File

@ -20,6 +20,21 @@ def create_if_missing(table_service: TableService) -> None:
table_service.create_table(TABLE_NAME) table_service.create_table(TABLE_NAME)
def update_allowed_aad_tenants(
table_service: TableService, resource_group: str, tenants: List[UUID]
) -> None:
create_if_missing(table_service)
as_str = [str(x) for x in tenants]
table_service.insert_or_merge_entity(
TABLE_NAME,
{
"PartitionKey": resource_group,
"RowKey": resource_group,
"allowed_aad_tenants": json.dumps(as_str),
},
)
def update_admins( def update_admins(
table_service: TableService, resource_group: str, admins: List[UUID] table_service: TableService, resource_group: str, admins: List[UUID]
) -> None: ) -> None:
@ -43,7 +58,8 @@ def main() -> None:
parser = argparse.ArgumentParser(formatter_class=formatter) parser = argparse.ArgumentParser(formatter_class=formatter)
parser.add_argument("resource_group") parser.add_argument("resource_group")
parser.add_argument("storage_account") parser.add_argument("storage_account")
parser.add_argument("admins", type=UUID, nargs="*") parser.add_argument("--admins", type=UUID, nargs="*")
parser.add_argument("--allowed_aad_tenants", type=UUID, nargs="*")
args = parser.parse_args() args = parser.parse_args()
client = get_client_from_cli_profile(StorageManagementClient) client = get_client_from_cli_profile(StorageManagementClient)
@ -53,7 +69,12 @@ def main() -> None:
table_service = TableService( table_service = TableService(
account_name=args.storage_account, account_key=storage_keys.keys[0].value account_name=args.storage_account, account_key=storage_keys.keys[0].value
) )
if args.admins:
update_admins(table_service, args.resource_group, args.admins) update_admins(table_service, args.resource_group, args.admins)
if args.allowed_aad_tenants:
update_allowed_aad_tenants(
table_service, args.resource_group, args.allowed_aad_tenants
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -258,7 +258,11 @@ def main() -> None:
EventFileAdded(container=Container("container-name"), filename="example.txt"), EventFileAdded(container=Container("container-name"), filename="example.txt"),
EventNodeHeartbeat(machine_id=UUID(int=0), pool_name=PoolName("example")), EventNodeHeartbeat(machine_id=UUID(int=0), pool_name=PoolName("example")),
EventTaskHeartbeat(task_id=UUID(int=0), job_id=UUID(int=0), config=task_config), EventTaskHeartbeat(task_id=UUID(int=0), job_id=UUID(int=0), config=task_config),
EventInstanceConfigUpdated(config=InstanceConfig(admins=[UUID(int=0)])), EventInstanceConfigUpdated(
config=InstanceConfig(
admins=[UUID(int=0)], allowed_aad_tenants=[UUID(int=0)]
)
),
] ]
# works around `mypy` not handling that Union has `__args__` # works around `mypy` not handling that Union has `__args__`

View File

@ -802,6 +802,8 @@ class InstanceConfig(BaseModel):
# if set, only admins can manage pools or scalesets # if set, only admins can manage pools or scalesets
allow_pool_management: bool = Field(default=True) allow_pool_management: bool = Field(default=True)
allowed_aad_tenants: List[UUID]
def update(self, config: "InstanceConfig") -> None: def update(self, config: "InstanceConfig") -> None:
for field in config.__fields__: for field in config.__fields__:
# If no admins are set, then ignore setting admins # If no admins are set, then ignore setting admins
@ -817,5 +819,16 @@ class InstanceConfig(BaseModel):
raise ValueError("admins must be None or contain at least one UUID") raise ValueError("admins must be None or contain at least one UUID")
return value return value
# At the moment, this only checks allowed_aad_tenants, however adding
# support for 3rd party JWT validation is anticipated in a future release.
@root_validator()
def check_instance_config(cls, values: Any) -> Any:
if "allowed_aad_tenants" not in values:
raise ValueError("missing allowed_aad_tenants")
if not len(values["allowed_aad_tenants"]):
raise ValueError("allowed_aad_tenants must not be empty")
return values
_check_hotfix() _check_hotfix()

View File

@ -11,9 +11,13 @@ from onefuzztypes.models import InstanceConfig
class TestInstanceConfig(unittest.TestCase): class TestInstanceConfig(unittest.TestCase):
def test_with_admins(self) -> None: def test_with_admins(self) -> None:
no_admins = InstanceConfig(admins=None) no_admins = InstanceConfig(admins=None, allowed_aad_tenants=[UUID(int=0)])
with_admins = InstanceConfig(admins=[UUID(int=0)]) with_admins = InstanceConfig(
with_admins_2 = InstanceConfig(admins=[UUID(int=1)]) admins=[UUID(int=0)], allowed_aad_tenants=[UUID(int=0)]
)
with_admins_2 = InstanceConfig(
admins=[UUID(int=1)], allowed_aad_tenants=[UUID(int=0)]
)
no_admins.update(with_admins) no_admins.update(with_admins)
self.assertEqual(no_admins.admins, None) self.assertEqual(no_admins.admins, None)