Group membership check (#1074)

This commit is contained in:
Cheick Keita
2021-11-22 14:06:03 -08:00
committed by GitHub
parent 8ff6509e61
commit aa74550160
12 changed files with 429 additions and 78 deletions

View File

@ -168,14 +168,6 @@ def query_microsoft_graph_list(
raise GraphQueryError("Expected data containing a list of values", None)
def is_member_of(group_id: str, member_id: str) -> bool:
body = {"groupIds": [group_id]}
response = query_microsoft_graph_list(
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
)
return group_id in response
@cached
def get_scaleset_identity_resource_path() -> str:
scaleset_id_name = "%s-scalesetid" % get_instance_name()

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Protocol
from uuid import UUID
from ..config import InstanceConfig
from .creds import query_microsoft_graph_list
class GroupMembershipChecker(Protocol):
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
"""Check if member is part of at least one of the groups"""
if member_id in group_ids:
return True
groups = self.get_groups(member_id)
for g in group_ids:
if g in groups:
return True
return False
def get_groups(self, member_id: UUID) -> List[UUID]:
"""Gets all the groups of the provided member"""
def create_group_membership_checker() -> GroupMembershipChecker:
config = InstanceConfig.fetch()
if config.group_membership:
return StaticGroupMembership(config.group_membership)
else:
return AzureADGroupMembership()
class AzureADGroupMembership(GroupMembershipChecker):
def get_groups(self, member_id: UUID) -> List[UUID]:
response = query_microsoft_graph_list(
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
)
return response
class StaticGroupMembership(GroupMembershipChecker):
def __init__(self, memberships: Dict[str, List[UUID]]):
self.memberships = memberships
def get_groups(self, member_id: UUID) -> List[UUID]:
return self.memberships.get(str(member_id), [])

View File

@ -5,37 +5,76 @@
import json
import logging
import os
import urllib
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from uuid import UUID
from azure.functions import HttpRequest, HttpResponse
from memoization import cached
from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error
from onefuzztypes.responses import BaseResponse
from pydantic import BaseModel # noqa: F401
from pydantic import ValidationError
from .azure.group_membership import create_group_membership_checker
from .config import InstanceConfig
from .orm import ModelMixin
from .request_access import RequestAccess
# We don't actually use these types at runtime at this time. Rather,
# these are used in a bound TypeVar. MyPy suggests to only import these
# types during type checking.
if TYPE_CHECKING:
from onefuzztypes.requests import BaseRequest # noqa: F401
from pydantic import BaseModel # noqa: F401
@cached(ttl=60)
def get_rules() -> Optional[RequestAccess]:
config = InstanceConfig.fetch()
if config.api_access_rules:
return RequestAccess.build(config.api_access_rules)
else:
return None
def check_access(req: HttpRequest) -> Optional[Error]:
if "ONEFUZZ_AAD_GROUP_ID" in os.environ:
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
logging.error(message)
return Error(
code=ErrorCode.INVALID_CONFIGURATION,
errors=[message],
)
else:
rules = get_rules()
# Noting to enforce if there are no rules.
if not rules:
return None
path = urllib.parse.urlparse(req.url).path
rule = rules.get_matching_rules(req.method, path)
# No restriction defined on this endpoint.
if not rule:
return None
member_id = UUID(req.headers["x-ms-client-principal-id"])
try:
membership_checker = create_group_membership_checker()
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
if not allowed:
logging.error(
"unauthorized access: %s is not authorized to access in %s",
member_id,
req.url,
)
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["not approved to use this endpoint"],
)
except Exception as e:
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["unable to interact with graph", str(e)],
)
return None
def ok(
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]

View File

@ -1,8 +1,7 @@
from typing import Dict, List
from typing import Dict, List, Optional
from uuid import UUID
from onefuzztypes.models import ApiAccessRule
from pydantic import parse_raw_as
class RuleConflictError(Exception):
@ -41,7 +40,7 @@ class RequestAccess:
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
methods = list(map(lambda m: m.upper(), methods))
segments = path.split("/")
segments = [s for s in path.split("/") if s != ""]
if len(segments) == 0:
return
@ -71,15 +70,14 @@ class RequestAccess:
for method in methods:
current_node.rules[method] = rules
def get_matching_rules(self, method: str, path: str) -> Rules:
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
method = method.upper()
segments = path.split("/")
segments = [s for s in path.split("/") if s != ""]
current_node = self.root
current_rule = None
if method in current_node.rules:
current_rule = current_node.rules[method]
else:
current_rule = RequestAccess.Rules()
current_segment_index = 0
@ -98,17 +96,13 @@ class RequestAccess:
return current_rule
@classmethod
def parse_rules(cls, rules_data: str) -> "RequestAccess":
rules = parse_raw_as(List[ApiAccessRule], rules_data)
return cls.build(rules)
@classmethod
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
request_access = RequestAccess()
for rule in rules:
for endpoint in rules:
rule = rules[endpoint]
request_access.__add_url__(
rule.methods,
rule.endpoint,
endpoint,
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
)