mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-13 18:48:09 +00:00
Group membership check (#1074)
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@ -233,7 +233,7 @@ jobs:
|
||||
pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl
|
||||
pip install -r __app__/requirements.txt
|
||||
pip install -r requirements-dev.txt
|
||||
pytest
|
||||
pytest tests
|
||||
flake8 .
|
||||
bandit -r ./__app__/
|
||||
black ./__app__/ ./tests --check
|
||||
|
@ -663,6 +663,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
||||
```json
|
||||
{
|
||||
"definitions": {
|
||||
"ApiAccessRule": {
|
||||
"properties": {
|
||||
"allowed_groups": {
|
||||
"items": {
|
||||
"format": "uuid",
|
||||
"type": "string"
|
||||
},
|
||||
"title": "Allowed Groups",
|
||||
"type": "array"
|
||||
},
|
||||
"methods": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"title": "Methods",
|
||||
"type": "array"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"methods",
|
||||
"allowed_groups"
|
||||
],
|
||||
"title": "ApiAccessRule",
|
||||
"type": "object"
|
||||
},
|
||||
"AzureMonitorExtensionConfig": {
|
||||
"properties": {
|
||||
"config_version": {
|
||||
@ -757,9 +782,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
||||
"title": "Allowed Aad Tenants",
|
||||
"type": "array"
|
||||
},
|
||||
"api_access_rules": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/definitions/ApiAccessRule"
|
||||
},
|
||||
"title": "Api Access Rules",
|
||||
"type": "object"
|
||||
},
|
||||
"extensions": {
|
||||
"$ref": "#/definitions/AzureVmExtensionConfig"
|
||||
},
|
||||
"group_membership": {
|
||||
"additionalProperties": {
|
||||
"items": {
|
||||
"format": "uuid",
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
"title": "Group Membership",
|
||||
"type": "object"
|
||||
},
|
||||
"network_config": {
|
||||
"$ref": "#/definitions/NetworkConfig"
|
||||
},
|
||||
@ -4933,6 +4976,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
||||
```json
|
||||
{
|
||||
"definitions": {
|
||||
"ApiAccessRule": {
|
||||
"properties": {
|
||||
"allowed_groups": {
|
||||
"items": {
|
||||
"format": "uuid",
|
||||
"type": "string"
|
||||
},
|
||||
"title": "Allowed Groups",
|
||||
"type": "array"
|
||||
},
|
||||
"methods": {
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"title": "Methods",
|
||||
"type": "array"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"methods",
|
||||
"allowed_groups"
|
||||
],
|
||||
"title": "ApiAccessRule",
|
||||
"type": "object"
|
||||
},
|
||||
"Architecture": {
|
||||
"description": "An enumeration.",
|
||||
"enum": [
|
||||
@ -5856,9 +5924,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
|
||||
"title": "Allowed Aad Tenants",
|
||||
"type": "array"
|
||||
},
|
||||
"api_access_rules": {
|
||||
"additionalProperties": {
|
||||
"$ref": "#/definitions/ApiAccessRule"
|
||||
},
|
||||
"title": "Api Access Rules",
|
||||
"type": "object"
|
||||
},
|
||||
"extensions": {
|
||||
"$ref": "#/definitions/AzureVmExtensionConfig"
|
||||
},
|
||||
"group_membership": {
|
||||
"additionalProperties": {
|
||||
"items": {
|
||||
"format": "uuid",
|
||||
"type": "string"
|
||||
},
|
||||
"type": "array"
|
||||
},
|
||||
"title": "Group Membership",
|
||||
"type": "object"
|
||||
},
|
||||
"network_config": {
|
||||
"$ref": "#/definitions/NetworkConfig"
|
||||
},
|
||||
|
@ -168,14 +168,6 @@ def query_microsoft_graph_list(
|
||||
raise GraphQueryError("Expected data containing a list of values", None)
|
||||
|
||||
|
||||
def is_member_of(group_id: str, member_id: str) -> bool:
|
||||
body = {"groupIds": [group_id]}
|
||||
response = query_microsoft_graph_list(
|
||||
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
|
||||
)
|
||||
return group_id in response
|
||||
|
||||
|
||||
@cached
|
||||
def get_scaleset_identity_resource_path() -> str:
|
||||
scaleset_id_name = "%s-scalesetid" % get_instance_name()
|
||||
|
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
51
src/api-service/__app__/onefuzzlib/azure/group_membership.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Dict, List, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from ..config import InstanceConfig
|
||||
from .creds import query_microsoft_graph_list
|
||||
|
||||
|
||||
class GroupMembershipChecker(Protocol):
|
||||
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
|
||||
"""Check if member is part of at least one of the groups"""
|
||||
if member_id in group_ids:
|
||||
return True
|
||||
|
||||
groups = self.get_groups(member_id)
|
||||
for g in group_ids:
|
||||
if g in groups:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
"""Gets all the groups of the provided member"""
|
||||
|
||||
|
||||
def create_group_membership_checker() -> GroupMembershipChecker:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.group_membership:
|
||||
return StaticGroupMembership(config.group_membership)
|
||||
else:
|
||||
return AzureADGroupMembership()
|
||||
|
||||
|
||||
class AzureADGroupMembership(GroupMembershipChecker):
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
response = query_microsoft_graph_list(
|
||||
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class StaticGroupMembership(GroupMembershipChecker):
|
||||
def __init__(self, memberships: Dict[str, List[UUID]]):
|
||||
self.memberships = memberships
|
||||
|
||||
def get_groups(self, member_id: UUID) -> List[UUID]:
|
||||
return self.memberships.get(str(member_id), [])
|
@ -5,35 +5,74 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import urllib
|
||||
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
|
||||
from uuid import UUID
|
||||
|
||||
from azure.functions import HttpRequest, HttpResponse
|
||||
from memoization import cached
|
||||
from onefuzztypes.enums import ErrorCode
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.responses import BaseResponse
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .azure.group_membership import create_group_membership_checker
|
||||
from .config import InstanceConfig
|
||||
from .orm import ModelMixin
|
||||
from .request_access import RequestAccess
|
||||
|
||||
# We don't actually use these types at runtime at this time. Rather,
|
||||
# these are used in a bound TypeVar. MyPy suggests to only import these
|
||||
# types during type checking.
|
||||
if TYPE_CHECKING:
|
||||
from onefuzztypes.requests import BaseRequest # noqa: F401
|
||||
from pydantic import BaseModel # noqa: F401
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_rules() -> Optional[RequestAccess]:
|
||||
config = InstanceConfig.fetch()
|
||||
if config.api_access_rules:
|
||||
return RequestAccess.build(config.api_access_rules)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def check_access(req: HttpRequest) -> Optional[Error]:
|
||||
if "ONEFUZZ_AAD_GROUP_ID" in os.environ:
|
||||
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
|
||||
logging.error(message)
|
||||
return Error(
|
||||
code=ErrorCode.INVALID_CONFIGURATION,
|
||||
errors=[message],
|
||||
rules = get_rules()
|
||||
|
||||
# Noting to enforce if there are no rules.
|
||||
if not rules:
|
||||
return None
|
||||
|
||||
path = urllib.parse.urlparse(req.url).path
|
||||
rule = rules.get_matching_rules(req.method, path)
|
||||
|
||||
# No restriction defined on this endpoint.
|
||||
if not rule:
|
||||
return None
|
||||
|
||||
member_id = UUID(req.headers["x-ms-client-principal-id"])
|
||||
|
||||
try:
|
||||
membership_checker = create_group_membership_checker()
|
||||
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
|
||||
if not allowed:
|
||||
logging.error(
|
||||
"unauthorized access: %s is not authorized to access in %s",
|
||||
member_id,
|
||||
req.url,
|
||||
)
|
||||
else:
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["not approved to use this endpoint"],
|
||||
)
|
||||
except Exception as e:
|
||||
return Error(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
errors=["unable to interact with graph", str(e)],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onefuzztypes.models import ApiAccessRule
|
||||
from pydantic import parse_raw_as
|
||||
|
||||
|
||||
class RuleConflictError(Exception):
|
||||
@ -41,7 +40,7 @@ class RequestAccess:
|
||||
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
|
||||
methods = list(map(lambda m: m.upper(), methods))
|
||||
|
||||
segments = path.split("/")
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
@ -71,15 +70,14 @@ class RequestAccess:
|
||||
for method in methods:
|
||||
current_node.rules[method] = rules
|
||||
|
||||
def get_matching_rules(self, method: str, path: str) -> Rules:
|
||||
def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
|
||||
method = method.upper()
|
||||
segments = path.split("/")
|
||||
segments = [s for s in path.split("/") if s != ""]
|
||||
current_node = self.root
|
||||
current_rule = None
|
||||
|
||||
if method in current_node.rules:
|
||||
current_rule = current_node.rules[method]
|
||||
else:
|
||||
current_rule = RequestAccess.Rules()
|
||||
|
||||
current_segment_index = 0
|
||||
|
||||
@ -98,17 +96,13 @@ class RequestAccess:
|
||||
return current_rule
|
||||
|
||||
@classmethod
|
||||
def parse_rules(cls, rules_data: str) -> "RequestAccess":
|
||||
rules = parse_raw_as(List[ApiAccessRule], rules_data)
|
||||
return cls.build(rules)
|
||||
|
||||
@classmethod
|
||||
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
|
||||
def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
|
||||
request_access = RequestAccess()
|
||||
for rule in rules:
|
||||
for endpoint in rules:
|
||||
rule = rules[endpoint]
|
||||
request_access.__add_url__(
|
||||
rule.methods,
|
||||
rule.endpoint,
|
||||
endpoint,
|
||||
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
|
||||
)
|
||||
|
||||
|
144
src/api-service/functional_tests/api_restriction_test.py
Normal file
144
src/api-service/functional_tests/api_restriction_test.py
Normal file
@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
from azure.cli.core import get_default_cli
|
||||
from onefuzz.api import Onefuzz
|
||||
from onefuzztypes.models import ApiAccessRule
|
||||
|
||||
|
||||
def az_cli(args: List[str]) -> Any:
|
||||
cli = get_default_cli()
|
||||
cli.logging_cls
|
||||
cli.invoke(args, out_file=open(os.devnull, "w"))
|
||||
if cli.result.result:
|
||||
return cli.result.result
|
||||
elif cli.result.error:
|
||||
raise cli.result.error
|
||||
|
||||
|
||||
class APIRestrictionTests:
|
||||
def __init__(
|
||||
self, resource_group: str = None, onefuzz_config_path: str = None
|
||||
) -> None:
|
||||
self.onefuzz = Onefuzz(config_path=onefuzz_config_path)
|
||||
self.intial_config = self.onefuzz.instance_config.get()
|
||||
|
||||
self.instance_name = urlparse(self.onefuzz.config().endpoint).netloc.split(".")[
|
||||
0
|
||||
]
|
||||
if resource_group:
|
||||
self.resource_group = resource_group
|
||||
else:
|
||||
self.resource_group = self.instance_name
|
||||
|
||||
def restore_config(self) -> None:
|
||||
self.onefuzz.instance_config.update(self.intial_config)
|
||||
|
||||
def assign(self, group_id: UUID, member_id: UUID) -> None:
|
||||
instance_config = self.onefuzz.instance_config.get()
|
||||
if instance_config.group_membership is None:
|
||||
instance_config.group_membership = {}
|
||||
|
||||
if member_id not in instance_config.group_membership:
|
||||
instance_config.group_membership[member_id] = []
|
||||
|
||||
if group_id not in instance_config.group_membership[member_id]:
|
||||
instance_config.group_membership[member_id].append(group_id)
|
||||
|
||||
self.onefuzz.instance_config.update(instance_config)
|
||||
|
||||
def assign_current_user(self, group_id: UUID) -> None:
|
||||
onefuzz_service_appId = az_cli(
|
||||
[
|
||||
"ad",
|
||||
"signed-in-user",
|
||||
"show",
|
||||
]
|
||||
)
|
||||
member_id = UUID(onefuzz_service_appId["objectId"])
|
||||
print(f"adding user {member_id}")
|
||||
self.assign(group_id, member_id)
|
||||
|
||||
def test_restriction_on_current_user(self) -> None:
|
||||
|
||||
print("Checking that the current user can get jobs")
|
||||
self.onefuzz.jobs.list()
|
||||
|
||||
print("Creating test group")
|
||||
group_id = uuid.uuid4()
|
||||
|
||||
print("Adding restriction to the jobs endpoint")
|
||||
instance_config = self.onefuzz.instance_config.get()
|
||||
if instance_config.api_access_rules is None:
|
||||
instance_config.api_access_rules = {}
|
||||
|
||||
instance_config.api_access_rules["/api/jobs"] = ApiAccessRule(
|
||||
allowed_groups=[group_id],
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
self.onefuzz.instance_config.update(instance_config)
|
||||
restart_instance(self.instance_name, self.resource_group)
|
||||
time.sleep(20)
|
||||
print("Checking that the current user cannot get jobs")
|
||||
|
||||
try:
|
||||
self.onefuzz.jobs.list()
|
||||
failed = False
|
||||
except Exception:
|
||||
failed = True
|
||||
pass
|
||||
|
||||
if not failed:
|
||||
raise Exception("Current user was able to get jobs")
|
||||
|
||||
print("Assigning current user to test group")
|
||||
self.assign_current_user(group_id)
|
||||
restart_instance(self.instance_name, self.resource_group)
|
||||
time.sleep(20)
|
||||
|
||||
print("Checking that the current user can get jobs")
|
||||
self.onefuzz.jobs.list()
|
||||
|
||||
|
||||
def restart_instance(instance_name: str, resource_group: str) -> None:
|
||||
print("Restarting instance")
|
||||
az_cli(
|
||||
[
|
||||
"functionapp",
|
||||
"restart",
|
||||
"--name",
|
||||
f"{instance_name}",
|
||||
"--resource-group",
|
||||
f"{resource_group}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config_path", default=None)
|
||||
parser.add_argument("--resource_group", default=None)
|
||||
args = parser.parse_args()
|
||||
tester = APIRestrictionTests(args.resource_group, args.config_path)
|
||||
|
||||
try:
|
||||
print("test current user restriction")
|
||||
tester.test_restriction_on_current_user()
|
||||
finally:
|
||||
tester.restore_config()
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
4
src/api-service/functional_tests/requirements.txt
Normal file
4
src/api-service/functional_tests/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
../../cli
|
||||
../../pytypes
|
||||
azure-cli-core==2.27.2
|
||||
azure-cli==2.27.2
|
38
src/api-service/tests/test_group_membership.py
Normal file
38
src/api-service/tests/test_group_membership.py
Normal file
@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from __app__.onefuzzlib.azure.group_membership import (
|
||||
GroupMembershipChecker,
|
||||
StaticGroupMembership,
|
||||
)
|
||||
|
||||
|
||||
class TestRequestAccess(unittest.TestCase):
|
||||
def test_empty(self) -> None:
|
||||
group_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
checker: GroupMembershipChecker = StaticGroupMembership({})
|
||||
|
||||
self.assertFalse(checker.is_member([group_id], user_id))
|
||||
self.assertTrue(checker.is_member([user_id], user_id))
|
||||
|
||||
def test_matching_user_id(self) -> None:
|
||||
group_id = uuid.uuid4()
|
||||
user_id1 = uuid.uuid4()
|
||||
user_id2 = uuid.uuid4()
|
||||
|
||||
checker: GroupMembershipChecker = StaticGroupMembership(
|
||||
{str(user_id1): [group_id]}
|
||||
)
|
||||
self.assertTrue(checker.is_member([user_id1], user_id1))
|
||||
self.assertFalse(checker.is_member([user_id1], user_id2))
|
||||
|
||||
def test_user_in_group(self) -> None:
|
||||
group_id1 = uuid.uuid4()
|
||||
group_id2 = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
checker: GroupMembershipChecker = StaticGroupMembership(
|
||||
{str(user_id): [group_id1]}
|
||||
)
|
||||
self.assertTrue(checker.is_member([group_id1], user_id))
|
||||
self.assertFalse(checker.is_member([group_id2], user_id))
|
@ -8,61 +8,60 @@ from __app__.onefuzzlib.request_access import RequestAccess, RuleConflictError
|
||||
|
||||
class TestRequestAccess(unittest.TestCase):
|
||||
def test_empty(self) -> None:
|
||||
request_access1 = RequestAccess.build([])
|
||||
request_access1 = RequestAccess.build({})
|
||||
rules1 = request_access1.get_matching_rules("get", "a/b/c")
|
||||
|
||||
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
|
||||
self.assertEqual(rules1, None, "expected nothing")
|
||||
|
||||
guid2 = uuid.uuid4()
|
||||
request_access1 = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid2],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
rules1 = request_access1.get_matching_rules("get", "")
|
||||
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing")
|
||||
self.assertEqual(rules1, None, "expected nothing")
|
||||
|
||||
def test_exact_match(self) -> None:
|
||||
|
||||
guid1 = uuid.uuid4()
|
||||
|
||||
request_access = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid1],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
||||
rules2 = request_access.get_matching_rules("get", "b/b/e")
|
||||
|
||||
assert rules1 is not None
|
||||
self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups")
|
||||
self.assertEqual(rules1.allowed_groups_ids[0], guid1)
|
||||
|
||||
self.assertEqual(len(rules2.allowed_groups_ids), 0, "expected nothing")
|
||||
self.assertEqual(rules2, None, "expected nothing")
|
||||
|
||||
def test_wildcard(self) -> None:
|
||||
guid1 = uuid.uuid4()
|
||||
|
||||
request_access = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"b/*/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="b/*/c",
|
||||
allowed_groups=[guid1],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
rules = request_access.get_matching_rules("get", "b/b/c")
|
||||
|
||||
assert rules is not None
|
||||
self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups")
|
||||
self.assertEqual(rules.allowed_groups_ids[0], guid1)
|
||||
|
||||
@ -71,18 +70,16 @@ class TestRequestAccess(unittest.TestCase):
|
||||
|
||||
try:
|
||||
RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid1],
|
||||
),
|
||||
ApiAccessRule(
|
||||
"a/b/c/": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[],
|
||||
),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
self.fail("this is expected to fail")
|
||||
@ -95,22 +92,21 @@ class TestRequestAccess(unittest.TestCase):
|
||||
guid2 = uuid.uuid4()
|
||||
|
||||
request_access = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/*/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/*/c",
|
||||
allowed_groups=[guid1],
|
||||
),
|
||||
ApiAccessRule(
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid2],
|
||||
),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
rules = request_access.get_matching_rules("get", "a/b/c")
|
||||
|
||||
assert rules is not None
|
||||
self.assertEqual(
|
||||
rules.allowed_groups_ids[0],
|
||||
guid2,
|
||||
@ -125,36 +121,36 @@ class TestRequestAccess(unittest.TestCase):
|
||||
guid3 = uuid.uuid4()
|
||||
|
||||
request_access = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid1],
|
||||
),
|
||||
ApiAccessRule(
|
||||
"f/*/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="f/*/c",
|
||||
allowed_groups=[guid2],
|
||||
),
|
||||
ApiAccessRule(
|
||||
"a/b": ApiAccessRule(
|
||||
methods=["post"],
|
||||
endpoint="a/b",
|
||||
allowed_groups=[guid3],
|
||||
),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
rules1 = request_access.get_matching_rules("get", "a/b/c/d")
|
||||
assert rules1 is not None
|
||||
self.assertEqual(
|
||||
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
||||
)
|
||||
|
||||
rules2 = request_access.get_matching_rules("get", "f/b/c/d")
|
||||
assert rules2 is not None
|
||||
self.assertEqual(
|
||||
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c"
|
||||
)
|
||||
|
||||
rules3 = request_access.get_matching_rules("post", "a/b/c/d")
|
||||
assert rules3 is not None
|
||||
self.assertEqual(
|
||||
rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b"
|
||||
)
|
||||
@ -165,26 +161,26 @@ class TestRequestAccess(unittest.TestCase):
|
||||
guid2 = uuid.uuid4()
|
||||
|
||||
request_access = RequestAccess.build(
|
||||
[
|
||||
ApiAccessRule(
|
||||
{
|
||||
"a/b/c": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c",
|
||||
allowed_groups=[guid1],
|
||||
),
|
||||
ApiAccessRule(
|
||||
"a/b/c/d": ApiAccessRule(
|
||||
methods=["get"],
|
||||
endpoint="a/b/c/d",
|
||||
allowed_groups=[guid2],
|
||||
),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
rules1 = request_access.get_matching_rules("get", "a/b/c")
|
||||
assert rules1 is not None
|
||||
self.assertEqual(
|
||||
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
|
||||
)
|
||||
|
||||
rules2 = request_access.get_matching_rules("get", "a/b/c/d")
|
||||
assert rules2 is not None
|
||||
self.assertEqual(
|
||||
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d"
|
||||
)
|
||||
|
@ -266,7 +266,7 @@
|
||||
"value": "[parameters('owner')]"
|
||||
}
|
||||
],
|
||||
"linuxFxVersion": "Python|3.7",
|
||||
"linuxFxVersion": "Python|3.8",
|
||||
"alwaysOn": true,
|
||||
"defaultDocuments": [],
|
||||
"httpLoggingEnabled": true,
|
||||
|
@ -838,10 +838,15 @@ class AzureVmExtensionConfig(BaseModel):
|
||||
|
||||
class ApiAccessRule(BaseModel):
|
||||
methods: List[str]
|
||||
endpoint: str
|
||||
allowed_groups: List[UUID]
|
||||
|
||||
|
||||
Endpoint = str
|
||||
# json dumps doesn't support UUID as dictionary key
|
||||
PrincipalID = str
|
||||
GroupId = UUID
|
||||
|
||||
|
||||
class InstanceConfig(BaseModel):
|
||||
# initial set of admins can only be set during deployment.
|
||||
# if admins are set, only admins can update instance configs.
|
||||
@ -857,6 +862,8 @@ class InstanceConfig(BaseModel):
|
||||
)
|
||||
extensions: Optional[AzureVmExtensionConfig]
|
||||
proxy_vm_sku: str = Field(default="Standard_B2s")
|
||||
api_access_rules: Optional[Dict[Endpoint, ApiAccessRule]] = None
|
||||
group_membership: Optional[Dict[PrincipalID, List[GroupId]]] = None
|
||||
|
||||
def update(self, config: "InstanceConfig") -> None:
|
||||
for field in config.__fields__:
|
||||
|
Reference in New Issue
Block a user