mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-14 11:08:06 +00:00
Group membership check (#1074)
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@ -233,7 +233,7 @@ jobs:
|
|||||||
pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl
|
pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl
|
||||||
pip install -r __app__/requirements.txt
|
pip install -r __app__/requirements.txt
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
pytest
|
pytest tests
|
||||||
flake8 .
|
flake8 .
|
||||||
bandit -r ./__app__/
|
bandit -r ./__app__/
|
||||||
black ./__app__/ ./tests --check
|
black ./__app__/ ./tests --check
|
||||||
|
@ -663,6 +663,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"definitions": {
|
"definitions": {
|
||||||
|
"ApiAccessRule": {
|
||||||
|
"properties": {
|
||||||
|
"allowed_groups": {
|
||||||
|
"items": {
|
||||||
|
"format": "uuid",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"title": "Allowed Groups",
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
"methods": {
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"title": "Methods",
|
||||||
|
"type": "array"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"methods",
|
||||||
|
"allowed_groups"
|
||||||
|
],
|
||||||
|
"title": "ApiAccessRule",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"AzureMonitorExtensionConfig": {
|
"AzureMonitorExtensionConfig": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"config_version": {
|
"config_version": {
|
||||||
@ -757,9 +782,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
|||||||
"title": "Allowed Aad Tenants",
|
"title": "Allowed Aad Tenants",
|
||||||
"type": "array"
|
"type": "array"
|
||||||
},
|
},
|
||||||
|
"api_access_rules": {
|
||||||
|
"additionalProperties": {
|
||||||
|
"$ref": "#/definitions/ApiAccessRule"
|
||||||
|
},
|
||||||
|
"title": "Api Access Rules",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"extensions": {
|
"extensions": {
|
||||||
"$ref": "#/definitions/AzureVmExtensionConfig"
|
"$ref": "#/definitions/AzureVmExtensionConfig"
|
||||||
},
|
},
|
||||||
|
"group_membership": {
|
||||||
|
"additionalProperties": {
|
||||||
|
"items": {
|
||||||
|
"format": "uuid",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
"title": "Group Membership",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"network_config": {
|
"network_config": {
|
||||||
"$ref": "#/definitions/NetworkConfig"
|
"$ref": "#/definitions/NetworkConfig"
|
||||||
},
|
},
|
||||||
@ -4933,6 +4976,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
|||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"definitions": {
|
"definitions": {
|
||||||
|
"ApiAccessRule": {
|
||||||
|
"properties": {
|
||||||
|
"allowed_groups": {
|
||||||
|
"items": {
|
||||||
|
"format": "uuid",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"title": "Allowed Groups",
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
"methods": {
|
||||||
|
"items": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"title": "Methods",
|
||||||
|
"type": "array"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"methods",
|
||||||
|
"allowed_groups"
|
||||||
|
],
|
||||||
|
"title": "ApiAccessRule",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"Architecture": {
|
"Architecture": {
|
||||||
"description": "An enumeration.",
|
"description": "An enumeration.",
|
||||||
"enum": [
|
"enum": [
|
||||||
@ -5856,9 +5924,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
|||||||
"title": "Allowed Aad Tenants",
|
"title": "Allowed Aad Tenants",
|
||||||
"type": "array"
|
"type": "array"
|
||||||
},
|
},
|
||||||
|
"api_access_rules": {
|
||||||
|
"additionalProperties": {
|
||||||
|
"$ref": "#/definitions/ApiAccessRule"
|
||||||
|
},
|
||||||
|
"title": "Api Access Rules",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"extensions": {
|
"extensions": {
|
||||||
"$ref": "#/definitions/AzureVmExtensionConfig"
|
"$ref": "#/definitions/AzureVmExtensionConfig"
|
||||||
},
|
},
|
||||||
|
"group_membership": {
|
||||||
|
"additionalProperties": {
|
||||||
|
"items": {
|
||||||
|
"format": "uuid",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
"title": "Group Membership",
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
"network_config": {
|
"network_config": {
|
||||||
"$ref": "#/definitions/NetworkConfig"
|
"$ref": "#/definitions/NetworkConfig"
|
||||||
},
|
},
|
||||||
|
@ -168,14 +168,6 @@ def query_microsoft_graph_list(
|
|||||||
raise GraphQueryError("Expected data containing a list of values", None)
|
raise GraphQueryError("Expected data containing a list of values", None)
|
||||||
|
|
||||||
|
|
||||||
def is_member_of(group_id: str, member_id: str) -> bool:
|
|
||||||
body = {"groupIds": [group_id]}
|
|
||||||
response = query_microsoft_graph_list(
|
|
||||||
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
|
|
||||||
)
|
|
||||||
return group_id in response
|
|
||||||
|
|
||||||
|
|
||||||
@cached
|
@cached
|
||||||
def get_scaleset_identity_resource_path() -> str:
|
def get_scaleset_identity_resource_path() -> str:
|
||||||
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
||||||
|
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
from typing import Dict, List, Protocol
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from ..config import InstanceConfig
|
||||||
|
from .creds import query_microsoft_graph_list
|
||||||
|
|
||||||
|
|
||||||
|
class GroupMembershipChecker(Protocol):
|
||||||
|
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
|
||||||
|
"""Check if member is part of at least one of the groups"""
|
||||||
|
if member_id in group_ids:
|
||||||
|
return True
|
||||||
|
|
||||||
|
groups = self.get_groups(member_id)
|
||||||
|
for g in group_ids:
|
||||||
|
if g in groups:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||||
|
"""Gets all the groups of the provided member"""
|
||||||
|
|
||||||
|
|
||||||
|
def create_group_membership_checker() -> GroupMembershipChecker:
|
||||||
|
config = InstanceConfig.fetch()
|
||||||
|
if config.group_membership:
|
||||||
|
return StaticGroupMembership(config.group_membership)
|
||||||
|
else:
|
||||||
|
return AzureADGroupMembership()
|
||||||
|
|
||||||
|
|
||||||
|
class AzureADGroupMembership(GroupMembershipChecker):
|
||||||
|
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||||
|
response = query_microsoft_graph_list(
|
||||||
|
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class StaticGroupMembership(GroupMembershipChecker):
|
||||||
|
def __init__(self, memberships: Dict[str, List[UUID]]):
|
||||||
|
self.memberships = memberships
|
||||||
|
|
||||||
|
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||||
|
return self.memberships.get(str(member_id), [])
|
@ -5,37 +5,76 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import urllib
|
||||||
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from azure.functions import HttpRequest, HttpResponse
|
from azure.functions import HttpRequest, HttpResponse
|
||||||
|
from memoization import cached
|
||||||
from onefuzztypes.enums import ErrorCode
|
from onefuzztypes.enums import ErrorCode
|
||||||
from onefuzztypes.models import Error
|
from onefuzztypes.models import Error
|
||||||
from onefuzztypes.responses import BaseResponse
|
from onefuzztypes.responses import BaseResponse
|
||||||
|
from pydantic import BaseModel # noqa: F401
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from .azure.group_membership import create_group_membership_checker
|
||||||
|
from .config import InstanceConfig
|
||||||
from .orm import ModelMixin
|
from .orm import ModelMixin
|
||||||
|
from .request_access import RequestAccess
|
||||||
|
|
||||||
# We don't actually use these types at runtime at this time. Rather,
|
# We don't actually use these types at runtime at this time. Rather,
|
||||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
# these are used in a bound TypeVar. MyPy suggests to only import these
|
||||||
# types during type checking.
|
# types during type checking.
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
from onefuzztypes.requests import BaseRequest # noqa: F401
|
||||||
from pydantic import BaseModel # noqa: F401
|
|
||||||
|
|
||||||
|
@cached(ttl=60)
|
||||||
|
def get_rules() -> Optional[RequestAccess]:
|
||||||
|
config = InstanceConfig.fetch()
|
||||||
|
if config.api_access_rules:
|
||||||
|
return RequestAccess.build(config.api_access_rules)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
def check_access(req: HttpRequest) -> Optional[Error]:
|
||||||
if "ONEFUZZ_AAD_GROUP_ID" in os.environ:
|
rules = get_rules()
|
||||||
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
|
|
||||||
logging.error(message)
|
# Noting to enforce if there are no rules.
|
||||||
return Error(
|
if not rules:
|
||||||
code=ErrorCode.INVALID_CONFIGURATION,
|
|
||||||
errors=[message],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
path = urllib.parse.urlparse(req.url).path
|
||||||
|
rule = rules.get_matching_rules(req.method, path)
|
||||||
|
|
||||||
|
# No restriction defined on this endpoint.
|
||||||
|
if not rule:
|
||||||
|
return None
|
||||||
|
|
||||||
|
member_id = UUID(req.headers["x-ms-client-principal-id"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
membership_checker = create_group_membership_checker()
|
||||||
|
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
|
||||||
|
if not allowed:
|
||||||
|
logging.error(
|
||||||
|
"unauthorized access: %s is not authorized to access in %s",
|
||||||
|
member_id,
|
||||||
|
req.url,
|
||||||
|
)
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNAUTHORIZED,
|
||||||
|
errors=["not approved to use this endpoint"],
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return Error(
|
||||||
|
code=ErrorCode.UNAUTHORIZED,
|
||||||
|
errors=["unable to interact with graph", str(e)],
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def ok(
|
def ok(
|
||||||
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
data: Union[BaseResponse, Sequence[BaseResponse], ModelMixin, Sequence[ModelMixin]]
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from onefuzztypes.models import ApiAccessRule
|
from onefuzztypes.models import ApiAccessRule
|
||||||
from pydantic import parse_raw_as
|
|
||||||
|
|
||||||
|
|
||||||
class RuleConflictError(Exception):
|
class RuleConflictError(Exception):
|
||||||
@ -41,7 +40,7 @@ class RequestAccess:
|
|||||||
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
||||||
methods = list(map(lambda m: m.upper(), methods))
|
methods = list(map(lambda m: m.upper(), methods))
|
||||||
|
|
||||||
segments = path.split("/")
|
segments = [s for s in path.split("/") if s != ""]
|
||||||
if len(segments) == 0:
|
if len(segments) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -71,15 +70,14 @@ class RequestAccess:
|
|||||||
for method in methods:
|
for method in methods:
|
||||||
current_node.rules[method] = rules
|
current_node.rules[method] = rules
|
||||||
|
|
||||||
def get_matching_rules(self, method: str, path: str) -> Rules:
|
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
|
||||||
method = method.upper()
|
method = method.upper()
|
||||||
segments = path.split("/")
|
segments = [s for s in path.split("/") if s != ""]
|
||||||
current_node = self.root
|
current_node = self.root
|
||||||
|
current_rule = None
|
||||||
|
|
||||||
if method in current_node.rules:
|
if method in current_node.rules:
|
||||||
current_rule = current_node.rules[method]
|
current_rule = current_node.rules[method]
|
||||||
else:
|
|
||||||
current_rule = RequestAccess.Rules()
|
|
||||||
|
|
||||||
current_segment_index = 0
|
current_segment_index = 0
|
||||||
|
|
||||||
@ -98,17 +96,13 @@ class RequestAccess:
|
|||||||
return current_rule
|
return current_rule
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_rules(cls, rules_data: str) -> "RequestAccess":
|
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
|
||||||
rules = parse_raw_as(List[ApiAccessRule], rules_data)
|
|
||||||
return cls.build(rules)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
|
|
||||||
request_access = RequestAccess()
|
request_access = RequestAccess()
|
||||||
for rule in rules:
|
for endpoint in rules:
|
||||||
|
rule = rules[endpoint]
|
||||||
request_access.__add_url__(
|
request_access.__add_url__(
|
||||||
rule.methods,
|
rule.methods,
|
||||||
rule.endpoint,
|
endpoint,
|
||||||
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
144
src/api-service/functional_tests/api_restriction_test.py
Normal file
144
src/api-service/functional_tests/api_restriction_test.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from azure.cli.core import get_default_cli
|
||||||
|
from onefuzz.api import Onefuzz
|
||||||
|
from onefuzztypes.models import ApiAccessRule
|
||||||
|
|
||||||
|
|
||||||
|
def az_cli(args: List[str]) -> Any:
|
||||||
|
cli = get_default_cli()
|
||||||
|
cli.logging_cls
|
||||||
|
cli.invoke(args, out_file=open(os.devnull, "w"))
|
||||||
|
if cli.result.result:
|
||||||
|
return cli.result.result
|
||||||
|
elif cli.result.error:
|
||||||
|
raise cli.result.error
|
||||||
|
|
||||||
|
|
||||||
|
class APIRestrictionTests:
|
||||||
|
def __init__(
|
||||||
|
self, resource_group: str = None, onefuzz_config_path: str = None
|
||||||
|
) -> None:
|
||||||
|
self.onefuzz = Onefuzz(config_path=onefuzz_config_path)
|
||||||
|
self.intial_config = self.onefuzz.instance_config.get()
|
||||||
|
|
||||||
|
self.instance_name = urlparse(self.onefuzz.config().endpoint).netloc.split(".")[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
if resource_group:
|
||||||
|
self.resource_group = resource_group
|
||||||
|
else:
|
||||||
|
self.resource_group = self.instance_name
|
||||||
|
|
||||||
|
def restore_config(self) -> None:
|
||||||
|
self.onefuzz.instance_config.update(self.intial_config)
|
||||||
|
|
||||||
|
def assign(self, group_id: UUID, member_id: UUID) -> None:
|
||||||
|
instance_config = self.onefuzz.instance_config.get()
|
||||||
|
if instance_config.group_membership is None:
|
||||||
|
instance_config.group_membership = {}
|
||||||
|
|
||||||
|
if member_id not in instance_config.group_membership:
|
||||||
|
instance_config.group_membership[member_id] = []
|
||||||
|
|
||||||
|
if group_id not in instance_config.group_membership[member_id]:
|
||||||
|
instance_config.group_membership[member_id].append(group_id)
|
||||||
|
|
||||||
|
self.onefuzz.instance_config.update(instance_config)
|
||||||
|
|
||||||
|
def assign_current_user(self, group_id: UUID) -> None:
|
||||||
|
onefuzz_service_appId = az_cli(
|
||||||
|
[
|
||||||
|
"ad",
|
||||||
|
"signed-in-user",
|
||||||
|
"show",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
member_id = UUID(onefuzz_service_appId["objectId"])
|
||||||
|
print(f"adding user {member_id}")
|
||||||
|
self.assign(group_id, member_id)
|
||||||
|
|
||||||
|
def test_restriction_on_current_user(self) -> None:
|
||||||
|
|
||||||
|
print("Checking that the current user can get jobs")
|
||||||
|
self.onefuzz.jobs.list()
|
||||||
|
|
||||||
|
print("Creating test group")
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
|
||||||
|
print("Adding restriction to the jobs endpoint")
|
||||||
|
instance_config = self.onefuzz.instance_config.get()
|
||||||
|
if instance_config.api_access_rules is None:
|
||||||
|
instance_config.api_access_rules = {}
|
||||||
|
|
||||||
|
instance_config.api_access_rules["/api/jobs"] = ApiAccessRule(
|
||||||
|
allowed_groups=[group_id],
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.onefuzz.instance_config.update(instance_config)
|
||||||
|
restart_instance(self.instance_name, self.resource_group)
|
||||||
|
time.sleep(20)
|
||||||
|
print("Checking that the current user cannot get jobs")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.onefuzz.jobs.list()
|
||||||
|
failed = False
|
||||||
|
except Exception:
|
||||||
|
failed = True
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not failed:
|
||||||
|
raise Exception("Current user was able to get jobs")
|
||||||
|
|
||||||
|
print("Assigning current user to test group")
|
||||||
|
self.assign_current_user(group_id)
|
||||||
|
restart_instance(self.instance_name, self.resource_group)
|
||||||
|
time.sleep(20)
|
||||||
|
|
||||||
|
print("Checking that the current user can get jobs")
|
||||||
|
self.onefuzz.jobs.list()
|
||||||
|
|
||||||
|
|
||||||
|
def restart_instance(instance_name: str, resource_group: str) -> None:
|
||||||
|
print("Restarting instance")
|
||||||
|
az_cli(
|
||||||
|
[
|
||||||
|
"functionapp",
|
||||||
|
"restart",
|
||||||
|
"--name",
|
||||||
|
f"{instance_name}",
|
||||||
|
"--resource-group",
|
||||||
|
f"{resource_group}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--config_path", default=None)
|
||||||
|
parser.add_argument("--resource_group", default=None)
|
||||||
|
args = parser.parse_args()
|
||||||
|
tester = APIRestrictionTests(args.resource_group, args.config_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("test current user restriction")
|
||||||
|
tester.test_restriction_on_current_user()
|
||||||
|
finally:
|
||||||
|
tester.restore_config()
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
4
src/api-service/functional_tests/requirements.txt
Normal file
4
src/api-service/functional_tests/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
../../cli
|
||||||
|
../../pytypes
|
||||||
|
azure-cli-core==2.27.2
|
||||||
|
azure-cli==2.27.2
|
38
src/api-service/tests/test_group_membership.py
Normal file
38
src/api-service/tests/test_group_membership.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import unittest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from __app__.onefuzzlib.azure.group_membership import (
|
||||||
|
GroupMembershipChecker,
|
||||||
|
StaticGroupMembership,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestAccess(unittest.TestCase):
|
||||||
|
def test_empty(self) -> None:
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
checker: GroupMembershipChecker = StaticGroupMembership({})
|
||||||
|
|
||||||
|
self.assertFalse(checker.is_member([group_id], user_id))
|
||||||
|
self.assertTrue(checker.is_member([user_id], user_id))
|
||||||
|
|
||||||
|
def test_matching_user_id(self) -> None:
|
||||||
|
group_id = uuid.uuid4()
|
||||||
|
user_id1 = uuid.uuid4()
|
||||||
|
user_id2 = uuid.uuid4()
|
||||||
|
|
||||||
|
checker: GroupMembershipChecker = StaticGroupMembership(
|
||||||
|
{str(user_id1): [group_id]}
|
||||||
|
)
|
||||||
|
self.assertTrue(checker.is_member([user_id1], user_id1))
|
||||||
|
self.assertFalse(checker.is_member([user_id1], user_id2))
|
||||||
|
|
||||||
|
def test_user_in_group(self) -> None:
|
||||||
|
group_id1 = uuid.uuid4()
|
||||||
|
group_id2 = uuid.uuid4()
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
checker: GroupMembershipChecker = StaticGroupMembership(
|
||||||
|
{str(user_id): [group_id1]}
|
||||||
|
)
|
||||||
|
self.assertTrue(checker.is_member([group_id1], user_id))
|
||||||
|
self.assertFalse(checker.is_member([group_id2], user_id))
|
@ -8,61 +8,60 @@ from __app__.onefuzzlib.request_access import RequestAccess, RuleConflictError
|
|||||||
|
|
||||||
class TestRequestAccess(unittest.TestCase):
|
class TestRequestAccess(unittest.TestCase):
|
||||||
def test_empty(self) -> None:
|
def test_empty(self) -> None:
|
||||||
request_access1 = RequestAccess.build([])
|
request_access1 = RequestAccess.build({})
|
||||||
rules1 = request_access1.get_matching_rules("get", "a/b/c")
|
rules1 = request_access1.get_matching_rules("get", "a/b/c")
|
||||||
|
|
||||||
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
|
self.assertEqual(rules1, None, "expected nothing")
|
||||||
|
|
||||||
guid2 = uuid.uuid4()
|
guid2 = uuid.uuid4()
|
||||||
request_access1 = RequestAccess.build(
|
request_access1 = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid2],
|
allowed_groups=[guid2],
|
||||||
)
|
)
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
rules1 = request_access1.get_matching_rules("get", "")
|
rules1 = request_access1.get_matching_rules("get", "")
|
||||||
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
|
self.assertEqual(rules1, None, "expected nothing")
|
||||||
|
|
||||||
def test_exact_match(self) -> None:
|
def test_exact_match(self) -> None:
|
||||||
|
|
||||||
guid1 = uuid.uuid4()
|
guid1 = uuid.uuid4()
|
||||||
|
|
||||||
request_access = RequestAccess.build(
|
request_access = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
)
|
)
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
||||||
rules2 = request_access.get_matching_rules("get", "b/b/e")
|
rules2 = request_access.get_matching_rules("get", "b/b/e")
|
||||||
|
|
||||||
|
assert rules1 is not None
|
||||||
self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups")
|
self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups")
|
||||||
self.assertEqual(rules1.allowed_groups_ids[0], guid1)
|
self.assertEqual(rules1.allowed_groups_ids[0], guid1)
|
||||||
|
|
||||||
self.assertEqual(len(rules2.allowed_groups_ids), 0, "expected nothing")
|
self.assertEqual(rules2, None, "expected nothing")
|
||||||
|
|
||||||
def test_wildcard(self) -> None:
|
def test_wildcard(self) -> None:
|
||||||
guid1 = uuid.uuid4()
|
guid1 = uuid.uuid4()
|
||||||
|
|
||||||
request_access = RequestAccess.build(
|
request_access = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"b/*/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="b/*/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
)
|
)
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
rules = request_access.get_matching_rules("get", "b/b/c")
|
rules = request_access.get_matching_rules("get", "b/b/c")
|
||||||
|
|
||||||
|
assert rules is not None
|
||||||
self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups")
|
self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups")
|
||||||
self.assertEqual(rules.allowed_groups_ids[0], guid1)
|
self.assertEqual(rules.allowed_groups_ids[0], guid1)
|
||||||
|
|
||||||
@ -71,18 +70,16 @@ class TestRequestAccess(unittest.TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
RequestAccess.build(
|
RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
),
|
),
|
||||||
ApiAccessRule(
|
"a/b/c/": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[],
|
allowed_groups=[],
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fail("this is expected to fail")
|
self.fail("this is expected to fail")
|
||||||
@ -95,22 +92,21 @@ class TestRequestAccess(unittest.TestCase):
|
|||||||
guid2 = uuid.uuid4()
|
guid2 = uuid.uuid4()
|
||||||
|
|
||||||
request_access = RequestAccess.build(
|
request_access = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/*/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/*/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
),
|
),
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid2],
|
allowed_groups=[guid2],
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
rules = request_access.get_matching_rules("get", "a/b/c")
|
rules = request_access.get_matching_rules("get", "a/b/c")
|
||||||
|
|
||||||
|
assert rules is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules.allowed_groups_ids[0],
|
rules.allowed_groups_ids[0],
|
||||||
guid2,
|
guid2,
|
||||||
@ -125,36 +121,36 @@ class TestRequestAccess(unittest.TestCase):
|
|||||||
guid3 = uuid.uuid4()
|
guid3 = uuid.uuid4()
|
||||||
|
|
||||||
request_access = RequestAccess.build(
|
request_access = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
),
|
),
|
||||||
ApiAccessRule(
|
"f/*/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="f/*/c",
|
|
||||||
allowed_groups=[guid2],
|
allowed_groups=[guid2],
|
||||||
),
|
),
|
||||||
ApiAccessRule(
|
"a/b": ApiAccessRule(
|
||||||
methods=["post"],
|
methods=["post"],
|
||||||
endpoint="a/b",
|
|
||||||
allowed_groups=[guid3],
|
allowed_groups=[guid3],
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
rules1 = request_access.get_matching_rules("get", "a/b/c/d")
|
rules1 = request_access.get_matching_rules("get", "a/b/c/d")
|
||||||
|
assert rules1 is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
||||||
)
|
)
|
||||||
|
|
||||||
rules2 = request_access.get_matching_rules("get", "f/b/c/d")
|
rules2 = request_access.get_matching_rules("get", "f/b/c/d")
|
||||||
|
assert rules2 is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c"
|
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c"
|
||||||
)
|
)
|
||||||
|
|
||||||
rules3 = request_access.get_matching_rules("post", "a/b/c/d")
|
rules3 = request_access.get_matching_rules("post", "a/b/c/d")
|
||||||
|
assert rules3 is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b"
|
rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b"
|
||||||
)
|
)
|
||||||
@ -165,26 +161,26 @@ class TestRequestAccess(unittest.TestCase):
|
|||||||
guid2 = uuid.uuid4()
|
guid2 = uuid.uuid4()
|
||||||
|
|
||||||
request_access = RequestAccess.build(
|
request_access = RequestAccess.build(
|
||||||
[
|
{
|
||||||
ApiAccessRule(
|
"a/b/c": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c",
|
|
||||||
allowed_groups=[guid1],
|
allowed_groups=[guid1],
|
||||||
),
|
),
|
||||||
ApiAccessRule(
|
"a/b/c/d": ApiAccessRule(
|
||||||
methods=["get"],
|
methods=["get"],
|
||||||
endpoint="a/b/c/d",
|
|
||||||
allowed_groups=[guid2],
|
allowed_groups=[guid2],
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
||||||
|
assert rules1 is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
||||||
)
|
)
|
||||||
|
|
||||||
rules2 = request_access.get_matching_rules("get", "a/b/c/d")
|
rules2 = request_access.get_matching_rules("get", "a/b/c/d")
|
||||||
|
assert rules2 is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d"
|
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d"
|
||||||
)
|
)
|
||||||
|
@ -266,7 +266,7 @@
|
|||||||
"value": "[parameters('owner')]"
|
"value": "[parameters('owner')]"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"linuxFxVersion": "Python|3.7",
|
"linuxFxVersion": "Python|3.8",
|
||||||
"alwaysOn": true,
|
"alwaysOn": true,
|
||||||
"defaultDocuments": [],
|
"defaultDocuments": [],
|
||||||
"httpLoggingEnabled": true,
|
"httpLoggingEnabled": true,
|
||||||
|
@ -838,10 +838,15 @@ class AzureVmExtensionConfig(BaseModel):
|
|||||||
|
|
||||||
class ApiAccessRule(BaseModel):
|
class ApiAccessRule(BaseModel):
|
||||||
methods: List[str]
|
methods: List[str]
|
||||||
endpoint: str
|
|
||||||
allowed_groups: List[UUID]
|
allowed_groups: List[UUID]
|
||||||
|
|
||||||
|
|
||||||
|
Endpoint = str
|
||||||
|
# json dumps doesn't support UUID as dictionary key
|
||||||
|
PrincipalID = str
|
||||||
|
GroupId = UUID
|
||||||
|
|
||||||
|
|
||||||
class InstanceConfig(BaseModel):
|
class InstanceConfig(BaseModel):
|
||||||
# initial set of admins can only be set during deployment.
|
# initial set of admins can only be set during deployment.
|
||||||
# if admins are set, only admins can update instance configs.
|
# if admins are set, only admins can update instance configs.
|
||||||
@ -857,6 +862,8 @@ class InstanceConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
extensions: Optional[AzureVmExtensionConfig]
|
extensions: Optional[AzureVmExtensionConfig]
|
||||||
proxy_vm_sku: str = Field(default="Standard_B2s")
|
proxy_vm_sku: str = Field(default="Standard_B2s")
|
||||||
|
api_access_rules: Optional[Dict[Endpoint, ApiAccessRule]] = None
|
||||||
|
group_membership: Optional[Dict[PrincipalID, List[GroupId]]] = None
|
||||||
|
|
||||||
def update(self, config: "InstanceConfig") -> None:
|
def update(self, config: "InstanceConfig") -> None:
|
||||||
for field in config.__fields__:
|
for field in config.__fields__:
|
||||||
|
Reference in New Issue
Block a user