Group membership check (#1074)

This commit is contained in:
Cheick Keita
2021-11-22 14:06:03 -08:00
committed by GitHub
parent 8ff6509e61
commit aa74550160
12 changed files with 429 additions and 78 deletions

View File

@ -233,7 +233,7 @@ jobs:
pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl
pip install -r __app__/requirements.txt
pip install -r requirements-dev.txt
pytest
pytest tests
flake8 .
bandit -r ./__app__/
black ./__app__/ ./tests --check

View File

@ -663,6 +663,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json
{
"definitions": {
"ApiAccessRule": {
"properties": {
"allowed_groups": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Groups",
"type": "array"
},
"methods": {
"items": {
"type": "string"
},
"title": "Methods",
"type": "array"
}
},
"required": [
"methods",
"allowed_groups"
],
"title": "ApiAccessRule",
"type": "object"
},
"AzureMonitorExtensionConfig": {
"properties": {
"config_version": {
@ -757,9 +782,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Allowed Aad Tenants",
"type": "array"
},
"api_access_rules": {
"additionalProperties": {
"$ref": "#/definitions/ApiAccessRule"
},
"title": "Api Access Rules",
"type": "object"
},
"extensions": {
"$ref": "#/definitions/AzureVmExtensionConfig"
},
"group_membership": {
"additionalProperties": {
"items": {
"format": "uuid",
"type": "string"
},
"type": "array"
},
"title": "Group Membership",
"type": "object"
},
"network_config": {
"$ref": "#/definitions/NetworkConfig"
},
@ -4933,6 +4976,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json
{
"definitions": {
"ApiAccessRule": {
"properties": {
"allowed_groups": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Groups",
"type": "array"
},
"methods": {
"items": {
"type": "string"
},
"title": "Methods",
"type": "array"
}
},
"required": [
"methods",
"allowed_groups"
],
"title": "ApiAccessRule",
"type": "object"
},
"Architecture": {
"description": "An enumeration.",
"enum": [
@ -5856,9 +5924,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Allowed Aad Tenants",
"type": "array"
},
"api_access_rules": {
"additionalProperties": {
"$ref": "#/definitions/ApiAccessRule"
},
"title": "Api Access Rules",
"type": "object"
},
"extensions": {
"$ref": "#/definitions/AzureVmExtensionConfig"
},
"group_membership": {
"additionalProperties": {
"items": {
"format": "uuid",
"type": "string"
},
"type": "array"
},
"title": "Group Membership",
"type": "object"
},
"network_config": {
"$ref": "#/definitions/NetworkConfig"
},

View File

@ -168,14 +168,6 @@ def query_microsoft_graph_list(
raise GraphQueryError("Expected data containing a list of values", None)
def is_member_of(group_id: str, member_id: str) -> bool:
body = {"groupIds": [group_id]}
response = query_microsoft_graph_list(
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
)
return group_id in response
@cached
def get_scaleset_identity_resource_path() -> str:
scaleset_id_name = "%s-scalesetid" % get_instance_name()

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Protocol
from uuid import UUID
from ..config import InstanceConfig
from .creds import query_microsoft_graph_list
class GroupMembershipChecker(Protocol):
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
"""Check if member is part of at least one of the groups"""
if member_id in group_ids:
return True
groups = self.get_groups(member_id)
for g in group_ids:
if g in groups:
return True
return False
def get_groups(self, member_id: UUID) -> List[UUID]:
"""Gets all the groups of the provided member"""
def create_group_membership_checker() -> GroupMembershipChecker:
config = InstanceConfig.fetch()
if config.group_membership:
return StaticGroupMembership(config.group_membership)
else:
return AzureADGroupMembership()
class AzureADGroupMembership(GroupMembershipChecker):
def get_groups(self, member_id: UUID) -> List[UUID]:
response = query_microsoft_graph_list(
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
)
return response
class StaticGroupMembership(GroupMembershipChecker):
def __init__(self, memberships: Dict[str, List[UUID]]):
self.memberships = memberships
def get_groups(self, member_id: UUID) -> List[UUID]:
return self.memberships.get(str(member_id), [])

View File

@ -5,37 +5,76 @@
import json
import logging
import os
import urllib
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from uuid import UUID
from azure.functions import HttpRequest, HttpResponse
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.responses import BaseResponse
from pydantic import BaseModel # noqa: F401
from pydantic import ValidationError
from .azure.group_membership import create_group_membership_checker
from .config import InstanceConfig
from .orm import ModelMixin
from .request_access import RequestAccess
# We don't actually use these types at runtime at this time. Rather,
# these are used in a bound TypeVar. MyPy suggests to only import these
# types during type checking.
if TYPE_CHECKING:
from onefuzztypes.requests import BaseRequest # noqa: F401
from pydantic import BaseModel # noqa: F401
@cached(ttl=60)
def get_rules() -> Optional[RequestAccess]:
config = InstanceConfig.fetch()
if config.api_access_rules:
return RequestAccess.build(config.api_access_rules)
else:
return None
def check_access(req: HttpRequest) -> Optional[Error]:
if "ONEFUZZ_AAD_GROUP_ID" in os.environ:
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
logging.error(message)
return Error(
code=ErrorCode.INVALID_CONFIGURATION,
errors=[message],
)
else:
rules = get_rules()
# Noting to enforce if there are no rules.
if not rules:
return None
path = urllib.parse.urlparse(req.url).path
rule = rules.get_matching_rules(req.method, path)
# No restriction defined on this endpoint.
if not rule:
return None
member_id = UUID(req.headers["x-ms-client-principal-id"])
try:
membership_checker = create_group_membership_checker()
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
if not allowed:
logging.error(
"unauthorized access: %s is not authorized to access in %s",
member_id,
req.url,
)
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["not approved to use this endpoint"],
)
except Exception as e:
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["unable to interact with graph", str(e)],
)
return None
def ok(
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]

View File

@ -1,8 +1,7 @@
from typing import Dict, List
from typing import Dict, List, Optional
from uuid import UUID
from onefuzztypes.models import ApiAccessRule
from pydantic import parse_raw_as
class RuleConflictError(Exception):
@ -41,7 +40,7 @@ class RequestAccess:
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
methods = list(map(lambda m: m.upper(), methods))
segments = path.split("/")
segments = [s for s in path.split("/") if s != ""]
if len(segments) == 0:
return
@ -71,15 +70,14 @@ class RequestAccess:
for method in methods:
current_node.rules[method] = rules
def get_matching_rules(self, method: str, path: str) -> Rules:
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
method = method.upper()
segments = path.split("/")
segments = [s for s in path.split("/") if s != ""]
current_node = self.root
current_rule = None
if method in current_node.rules:
current_rule = current_node.rules[method]
else:
current_rule = RequestAccess.Rules()
current_segment_index = 0
@ -98,17 +96,13 @@ class RequestAccess:
return current_rule
@classmethod
def parse_rules(cls, rules_data: str) -> "RequestAccess":
rules = parse_raw_as(List[ApiAccessRule], rules_data)
return cls.build(rules)
@classmethod
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
request_access = RequestAccess()
for rule in rules:
for endpoint in rules:
rule = rules[endpoint]
request_access.__add_url__(
rule.methods,
rule.endpoint,
endpoint,
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
)

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import time
import uuid
from typing import Any, List
from urllib.parse import urlparse
from uuid import UUID
from azure.cli.core import get_default_cli
from onefuzz.api import Onefuzz
from onefuzztypes.models import ApiAccessRule
def az_cli(args: List[str]) -> Any:
cli = get_default_cli()
cli.logging_cls
cli.invoke(args, out_file=open(os.devnull, "w"))
if cli.result.result:
return cli.result.result
elif cli.result.error:
raise cli.result.error
class APIRestrictionTests:
def __init__(
self, resource_group: str = None, onefuzz_config_path: str = None
) -> None:
self.onefuzz = Onefuzz(config_path=onefuzz_config_path)
self.intial_config = self.onefuzz.instance_config.get()
self.instance_name = urlparse(self.onefuzz.config().endpoint).netloc.split(".")[
0
]
if resource_group:
self.resource_group = resource_group
else:
self.resource_group = self.instance_name
def restore_config(self) -> None:
self.onefuzz.instance_config.update(self.intial_config)
def assign(self, group_id: UUID, member_id: UUID) -> None:
instance_config = self.onefuzz.instance_config.get()
if instance_config.group_membership is None:
instance_config.group_membership = {}
if member_id not in instance_config.group_membership:
instance_config.group_membership[member_id] = []
if group_id not in instance_config.group_membership[member_id]:
instance_config.group_membership[member_id].append(group_id)
self.onefuzz.instance_config.update(instance_config)
def assign_current_user(self, group_id: UUID) -> None:
onefuzz_service_appId = az_cli(
[
"ad",
"signed-in-user",
"show",
]
)
member_id = UUID(onefuzz_service_appId["objectId"])
print(f"adding user {member_id}")
self.assign(group_id, member_id)
def test_restriction_on_current_user(self) -> None:
print("Checking that the current user can get jobs")
self.onefuzz.jobs.list()
print("Creating test group")
group_id = uuid.uuid4()
print("Adding restriction to the jobs endpoint")
instance_config = self.onefuzz.instance_config.get()
if instance_config.api_access_rules is None:
instance_config.api_access_rules = {}
instance_config.api_access_rules["/api/jobs"] = ApiAccessRule(
allowed_groups=[group_id],
methods=["GET"],
)
self.onefuzz.instance_config.update(instance_config)
restart_instance(self.instance_name, self.resource_group)
time.sleep(20)
print("Checking that the current user cannot get jobs")
try:
self.onefuzz.jobs.list()
failed = False
except Exception:
failed = True
pass
if not failed:
raise Exception("Current user was able to get jobs")
print("Assigning current user to test group")
self.assign_current_user(group_id)
restart_instance(self.instance_name, self.resource_group)
time.sleep(20)
print("Checking that the current user can get jobs")
self.onefuzz.jobs.list()
def restart_instance(instance_name: str, resource_group: str) -> None:
print("Restarting instance")
az_cli(
[
"functionapp",
"restart",
"--name",
f"{instance_name}",
"--resource-group",
f"{resource_group}",
]
)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", default=None)
parser.add_argument("--resource_group", default=None)
args = parser.parse_args()
tester = APIRestrictionTests(args.resource_group, args.config_path)
try:
print("test current user restriction")
tester.test_restriction_on_current_user()
finally:
tester.restore_config()
pass
if __name__ == "__main__":
main()

View File

@ -0,0 +1,4 @@
../../cli
../../pytypes
azure-cli-core==2.27.2
azure-cli==2.27.2

View File

@ -0,0 +1,38 @@
import unittest
import uuid
from __app__.onefuzzlib.azure.group_membership import (
GroupMembershipChecker,
StaticGroupMembership,
)
class TestRequestAccess(unittest.TestCase):
def test_empty(self) -> None:
group_id = uuid.uuid4()
user_id = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership({})
self.assertFalse(checker.is_member([group_id], user_id))
self.assertTrue(checker.is_member([user_id], user_id))
def test_matching_user_id(self) -> None:
group_id = uuid.uuid4()
user_id1 = uuid.uuid4()
user_id2 = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership(
{str(user_id1): [group_id]}
)
self.assertTrue(checker.is_member([user_id1], user_id1))
self.assertFalse(checker.is_member([user_id1], user_id2))
def test_user_in_group(self) -> None:
group_id1 = uuid.uuid4()
group_id2 = uuid.uuid4()
user_id = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership(
{str(user_id): [group_id1]}
)
self.assertTrue(checker.is_member([group_id1], user_id))
self.assertFalse(checker.is_member([group_id2], user_id))

View File

@ -8,61 +8,60 @@ from __app__.onefuzzlib.request_access import RequestAccess, RuleConflictError
class TestRequestAccess(unittest.TestCase):
def test_empty(self) -> None:
request_access1 = RequestAccess.build([])
request_access1 = RequestAccess.build({})
rules1 = request_access1.get_matching_rules("get", "a/b/c")
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
self.assertEqual(rules1, None, "expected nothing")
guid2 = uuid.uuid4()
request_access1 = RequestAccess.build(
[
ApiAccessRule(
{
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid2],
)
]
}
)
rules1 = request_access1.get_matching_rules("get", "")
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
self.assertEqual(rules1, None, "expected nothing")
def test_exact_match(self) -> None:
guid1 = uuid.uuid4()
request_access = RequestAccess.build(
[
ApiAccessRule(
{
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1],
)
]
}
)
rules1 = request_access.get_matching_rules("get", "a/b/c")
rules2 = request_access.get_matching_rules("get", "b/b/e")
assert rules1 is not None
self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups")
self.assertEqual(rules1.allowed_groups_ids[0], guid1)
self.assertEqual(len(rules2.allowed_groups_ids), 0, "expected nothing")
self.assertEqual(rules2, None, "expected nothing")
def test_wildcard(self) -> None:
guid1 = uuid.uuid4()
request_access = RequestAccess.build(
[
ApiAccessRule(
{
"b/*/c": ApiAccessRule(
methods=["get"],
endpoint="b/*/c",
allowed_groups=[guid1],
)
]
}
)
rules = request_access.get_matching_rules("get", "b/b/c")
assert rules is not None
self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups")
self.assertEqual(rules.allowed_groups_ids[0], guid1)
@ -71,18 +70,16 @@ class TestRequestAccess(unittest.TestCase):
try:
RequestAccess.build(
[
ApiAccessRule(
{
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1],
),
ApiAccessRule(
"a/b/c/": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[],
),
]
}
)
self.fail("this is expected to fail")
@ -95,22 +92,21 @@ class TestRequestAccess(unittest.TestCase):
guid2 = uuid.uuid4()
request_access = RequestAccess.build(
[
ApiAccessRule(
{
"a/*/c": ApiAccessRule(
methods=["get"],
endpoint="a/*/c",
allowed_groups=[guid1],
),
ApiAccessRule(
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid2],
),
]
}
)
rules = request_access.get_matching_rules("get", "a/b/c")
assert rules is not None
self.assertEqual(
rules.allowed_groups_ids[0],
guid2,
@ -125,36 +121,36 @@ class TestRequestAccess(unittest.TestCase):
guid3 = uuid.uuid4()
request_access = RequestAccess.build(
[
ApiAccessRule(
{
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1],
),
ApiAccessRule(
"f/*/c": ApiAccessRule(
methods=["get"],
endpoint="f/*/c",
allowed_groups=[guid2],
),
ApiAccessRule(
"a/b": ApiAccessRule(
methods=["post"],
endpoint="a/b",
allowed_groups=[guid3],
),
]
}
)
rules1 = request_access.get_matching_rules("get", "a/b/c/d")
assert rules1 is not None
self.assertEqual(
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
)
rules2 = request_access.get_matching_rules("get", "f/b/c/d")
assert rules2 is not None
self.assertEqual(
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c"
)
rules3 = request_access.get_matching_rules("post", "a/b/c/d")
assert rules3 is not None
self.assertEqual(
rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b"
)
@ -165,26 +161,26 @@ class TestRequestAccess(unittest.TestCase):
guid2 = uuid.uuid4()
request_access = RequestAccess.build(
[
ApiAccessRule(
{
"a/b/c": ApiAccessRule(
methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1],
),
ApiAccessRule(
"a/b/c/d": ApiAccessRule(
methods=["get"],
endpoint="a/b/c/d",
allowed_groups=[guid2],
),
]
}
)
rules1 = request_access.get_matching_rules("get", "a/b/c")
assert rules1 is not None
self.assertEqual(
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
)
rules2 = request_access.get_matching_rules("get", "a/b/c/d")
assert rules2 is not None
self.assertEqual(
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d"
)

View File

@ -266,7 +266,7 @@
"value": "[parameters('owner')]"
}
],
"linuxFxVersion": "Python|3.7",
"linuxFxVersion": "Python|3.8",
"alwaysOn": true,
"defaultDocuments": [],
"httpLoggingEnabled": true,

View File

@ -838,10 +838,15 @@ class AzureVmExtensionConfig(BaseModel):
class ApiAccessRule(BaseModel):
methods: List[str]
endpoint: str
allowed_groups: List[UUID]
Endpoint = str
# json dumps doesn't support UUID as dictionary key
PrincipalID = str
GroupId = UUID
class InstanceConfig(BaseModel):
# initial set of admins can only be set during deployment.
# if admins are set, only admins can update instance configs.
@ -857,6 +862,8 @@ class InstanceConfig(BaseModel):
)
extensions: Optional[AzureVmExtensionConfig]
proxy_vm_sku: str = Field(default="Standard_B2s")
api_access_rules: Optional[Dict[Endpoint, ApiAccessRule]] = None
group_membership: Optional[Dict[PrincipalID, List[GroupId]]] = None
def update(self, config: "InstanceConfig") -> None:
for field in config.__fields__: