mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 03:48:09 +00:00
Group membership check (#1074)
This commit is contained in:
@ -168,14 +168,6 @@ def query_microsoft_graph_list(
|
||||
raise GraphQueryError("Expected data containing a list of values", None)
|
||||
|
||||
|
||||
def is_member_of(group_id: str, member_id: str) -> bool:
|
||||
body = {"groupIds": [group_id]}
|
||||
response = query_microsoft_graph_list(
|
||||
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
|
||||
)
|
||||
return group_id in response
|
||||
|
||||
|
||||
@cached
|
||||
def get_scaleset_identity_resource_path() -> str:
|
||||
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
||||
|
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from ..config import InstanceConfig
|
||||
from .creds import query_microsoft_graph_list
|
||||
|
||||
|
||||
class GroupMembershipChecker(Protocol):
|
||||
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
|
||||
"""Check if member is part of at least one of the groups"""
|
||||
if member_id in group_ids:
|
||||
return True
|
||||
|
||||
groups = self.get_groups(member_id)
|
||||
for g in group_ids:
|
||||
if g in groups:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
"""Gets all the groups of the provided member"""
|
||||
|
||||
|
||||
def create_group_membership_checker() -> GroupMembershipChecker:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.group_membership:
|
||||
return StaticGroupMembership(config.group_membership)
|
||||
else:
|
||||
return AzureADGroupMembership()
|
||||
|
||||
|
||||
class AzureADGroupMembership(GroupMembershipChecker):
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
response = query_microsoft_graph_list(
|
||||
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class StaticGroupMembership(GroupMembershipChecker):
|
||||
def __init__(self, memberships: Dict[str, List[UUID]]):
|
||||
self.memberships = memberships
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
return self.memberships.get(str(member_id), [])
|
@ -5,37 +5,76 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import urllib
|
||||
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.functions import HttpRequest, HttpResponse
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.responses import BaseResponse
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .azure.group_membership import create_group_membership_checker
|
||||
from .config import InstanceConfig
|
||||
from .orm import ModelMixin
|
||||
from .request_access import RequestAccess
|
||||
|
||||
# We don't actually use these types at runtime at this time. Rather,
|
||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
||||
# types during type checking.
|
||||
if TYPE_CHECKING:
|
||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_rules() -> Optional[RequestAccess]:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.api_access_rules:
|
||||
return RequestAccess.build(config.api_access_rules)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
||||
if "ONEFUZZ_AAD_GROUP_ID" in os.environ:
|
||||
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
|
||||
logging.error(message)
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_CONFIGURATION,
|
||||
errors=[message],
|
||||
)
|
||||
else:
|
||||
rules = get_rules()
|
||||
|
||||
# Noting to enforce if there are no rules.
|
||||
if not rules:
|
||||
return None
|
||||
|
||||
path = urllib.parse.urlparse(req.url).path
|
||||
rule = rules.get_matching_rules(req.method, path)
|
||||
|
||||
# No restriction defined on this endpoint.
|
||||
if not rule:
|
||||
return None
|
||||
|
||||
member_id = UUID(req.headers["x-ms-client-principal-id"])
|
||||
|
||||
try:
|
||||
membership_checker = create_group_membership_checker()
|
||||
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
|
||||
if not allowed:
|
||||
logging.error(
|
||||
"unauthorized access: %s is not authorized to access in %s",
|
||||
member_id,
|
||||
req.url,
|
||||
)
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["not approved to use this endpoint"],
|
||||
)
|
||||
except Exception as e:
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["unable to interact with graph", str(e)],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def ok(
|
||||
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
||||
|
@ -1,8 +1,7 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.models import ApiAccessRule
|
||||
from pydantic import parse_raw_as
|
||||
|
||||
|
||||
class RuleConflictError(Exception):
|
||||
@ -41,7 +40,7 @@ class RequestAccess:
|
||||
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
||||
methods = list(map(lambda m: m.upper(), methods))
|
||||
|
||||
segments = path.split("/")
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
@ -71,15 +70,14 @@ class RequestAccess:
|
||||
for method in methods:
|
||||
current_node.rules[method] = rules
|
||||
|
||||
def get_matching_rules(self, method: str, path: str) -> Rules:
|
||||
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
|
||||
method = method.upper()
|
||||
segments = path.split("/")
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
current_node = self.root
|
||||
current_rule = None
|
||||
|
||||
if method in current_node.rules:
|
||||
current_rule = current_node.rules[method]
|
||||
else:
|
||||
current_rule = RequestAccess.Rules()
|
||||
|
||||
current_segment_index = 0
|
||||
|
||||
@ -98,17 +96,13 @@ class RequestAccess:
|
||||
return current_rule
|
||||
|
||||
@classmethod
|
||||
def parse_rules(cls, rules_data: str) -> "RequestAccess":
|
||||
rules = parse_raw_as(List[ApiAccessRule], rules_data)
|
||||
return cls.build(rules)
|
||||
|
||||
@classmethod
|
||||
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
|
||||
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
|
||||
request_access = RequestAccess()
|
||||
for rule in rules:
|
||||
for endpoint in rules:
|
||||
rule = rules[endpoint]
|
||||
request_access.__add_url__(
|
||||
rule.methods,
|
||||
rule.endpoint,
|
||||
endpoint,
|
||||
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user