Group membership check (#1074)

This commit is contained in:
Cheick Keita
2021-11-22 14:06:03 -08:00
committed by GitHub
parent 8ff6509e61
commit aa74550160
12 changed files with 429 additions and 78 deletions

View File

@ -233,7 +233,7 @@ jobs:
pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl pip install ${GITHUB_WORKSPACE}/artifacts/sdk/onefuzztypes-*.whl
pip install -r __app__/requirements.txt pip install -r __app__/requirements.txt
pip install -r requirements-dev.txt pip install -r requirements-dev.txt
pytest pytest tests
flake8 . flake8 .
bandit -r ./__app__/ bandit -r ./__app__/
black ./__app__/ ./tests --check black ./__app__/ ./tests --check

View File

@ -663,6 +663,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"definitions": { "definitions": {
"ApiAccessRule": {
"properties": {
"allowed_groups": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Groups",
"type": "array"
},
"methods": {
"items": {
"type": "string"
},
"title": "Methods",
"type": "array"
}
},
"required": [
"methods",
"allowed_groups"
],
"title": "ApiAccessRule",
"type": "object"
},
"AzureMonitorExtensionConfig": { "AzureMonitorExtensionConfig": {
"properties": { "properties": {
"config_version": { "config_version": {
@ -757,9 +782,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Allowed Aad Tenants", "title": "Allowed Aad Tenants",
"type": "array" "type": "array"
}, },
"api_access_rules": {
"additionalProperties": {
"$ref": "#/definitions/ApiAccessRule"
},
"title": "Api Access Rules",
"type": "object"
},
"extensions": { "extensions": {
"$ref": "#/definitions/AzureVmExtensionConfig" "$ref": "#/definitions/AzureVmExtensionConfig"
}, },
"group_membership": {
"additionalProperties": {
"items": {
"format": "uuid",
"type": "string"
},
"type": "array"
},
"title": "Group Membership",
"type": "object"
},
"network_config": { "network_config": {
"$ref": "#/definitions/NetworkConfig" "$ref": "#/definitions/NetworkConfig"
}, },
@ -4933,6 +4976,31 @@ Each event will be submitted via HTTP POST to the user provided URL.
```json ```json
{ {
"definitions": { "definitions": {
"ApiAccessRule": {
"properties": {
"allowed_groups": {
"items": {
"format": "uuid",
"type": "string"
},
"title": "Allowed Groups",
"type": "array"
},
"methods": {
"items": {
"type": "string"
},
"title": "Methods",
"type": "array"
}
},
"required": [
"methods",
"allowed_groups"
],
"title": "ApiAccessRule",
"type": "object"
},
"Architecture": { "Architecture": {
"description": "An enumeration.", "description": "An enumeration.",
"enum": [ "enum": [
@ -5856,9 +5924,27 @@ Each event will be submitted via HTTP POST to the user provided URL.
"title": "Allowed Aad Tenants", "title": "Allowed Aad Tenants",
"type": "array" "type": "array"
}, },
"api_access_rules": {
"additionalProperties": {
"$ref": "#/definitions/ApiAccessRule"
},
"title": "Api Access Rules",
"type": "object"
},
"extensions": { "extensions": {
"$ref": "#/definitions/AzureVmExtensionConfig" "$ref": "#/definitions/AzureVmExtensionConfig"
}, },
"group_membership": {
"additionalProperties": {
"items": {
"format": "uuid",
"type": "string"
},
"type": "array"
},
"title": "Group Membership",
"type": "object"
},
"network_config": { "network_config": {
"$ref": "#/definitions/NetworkConfig" "$ref": "#/definitions/NetworkConfig"
}, },

View File

@ -168,14 +168,6 @@ def query_microsoft_graph_list(
raise GraphQueryError("Expected data containing a list of values", None) raise GraphQueryError("Expected data containing a list of values", None)
def is_member_of(group_id: str, member_id: str) -> bool:
body = {"groupIds": [group_id]}
response = query_microsoft_graph_list(
method="POST", resource=f"users/{member_id}/checkMemberGroups", body=body
)
return group_id in response
@cached @cached
def get_scaleset_identity_resource_path() -> str: def get_scaleset_identity_resource_path() -> str:
scaleset_id_name = "%s-scalesetid" % get_instance_name() scaleset_id_name = "%s-scalesetid" % get_instance_name()

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Dict, List, Protocol
from uuid import UUID
from ..config import InstanceConfig
from .creds import query_microsoft_graph_list
class GroupMembershipChecker(Protocol):
def is_member(self, group_ids: List[UUID], member_id: UUID) -> bool:
"""Check if member is part of at least one of the groups"""
if member_id in group_ids:
return True
groups = self.get_groups(member_id)
for g in group_ids:
if g in groups:
return True
return False
def get_groups(self, member_id: UUID) -> List[UUID]:
"""Gets all the groups of the provided member"""
def create_group_membership_checker() -> GroupMembershipChecker:
config = InstanceConfig.fetch()
if config.group_membership:
return StaticGroupMembership(config.group_membership)
else:
return AzureADGroupMembership()
class AzureADGroupMembership(GroupMembershipChecker):
def get_groups(self, member_id: UUID) -> List[UUID]:
response = query_microsoft_graph_list(
method="GET", resource=f"users/{member_id}/transitiveMemberOf"
)
return response
class StaticGroupMembership(GroupMembershipChecker):
def __init__(self, memberships: Dict[str, List[UUID]]):
self.memberships = memberships
def get_groups(self, member_id: UUID) -> List[UUID]:
return self.memberships.get(str(member_id), [])

View File

@ -5,35 +5,74 @@
import json import json
import logging import logging
import os import urllib
from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union from typing import TYPE_CHECKING, Optional, Sequence, Type, TypeVar, Union
from uuid import UUID from uuid import UUID
from azure.functions import HttpRequest, HttpResponse from azure.functions import HttpRequest, HttpResponse
from memoization import cached
from onefuzztypes.enums import ErrorCode from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error from onefuzztypes.models import Error
from onefuzztypes.responses import BaseResponse from onefuzztypes.responses import BaseResponse
from pydantic import BaseModel # noqa: F401
from pydantic import ValidationError from pydantic import ValidationError
from .azure.group_membership import create_group_membership_checker
from .config import InstanceConfig
from .orm import ModelMixin from .orm import ModelMixin
from .request_access import RequestAccess
# We don't actually use these types at runtime at this time. Rather, # We don't actually use these types at runtime at this time. Rather,
# these are used in a bound TypeVar. MyPy suggests to only import these # these are used in a bound TypeVar. MyPy suggests to only import these
# types during type checking. # types during type checking.
if TYPE_CHECKING: if TYPE_CHECKING:
from onefuzztypes.requests import BaseRequest # noqa: F401 from onefuzztypes.requests import BaseRequest # noqa: F401
from pydantic import BaseModel # noqa: F401
@cached(ttl=60)
def get_rules() -> Optional[RequestAccess]:
config = InstanceConfig.fetch()
if config.api_access_rules:
return RequestAccess.build(config.api_access_rules)
else:
return None
def check_access(req: HttpRequest) -> Optional[Error]: def check_access(req: HttpRequest) -> Optional[Error]:
if "ONEFUZZ_AAD_GROUP_ID" in os.environ: rules = get_rules()
message = "ONEFUZZ_AAD_GROUP_ID configuration not supported"
logging.error(message) # Noting to enforce if there are no rules.
return Error( if not rules:
code=ErrorCode.INVALID_CONFIGURATION, return None
errors=[message],
path = urllib.parse.urlparse(req.url).path
rule = rules.get_matching_rules(req.method, path)
# No restriction defined on this endpoint.
if not rule:
return None
member_id = UUID(req.headers["x-ms-client-principal-id"])
try:
membership_checker = create_group_membership_checker()
allowed = membership_checker.is_member(rule.allowed_groups_ids, member_id)
if not allowed:
logging.error(
"unauthorized access: %s is not authorized to access in %s",
member_id,
req.url,
) )
else: return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["not approved to use this endpoint"],
)
except Exception as e:
return Error(
code=ErrorCode.UNAUTHORIZED,
errors=["unable to interact with graph", str(e)],
)
return None return None

View File

@ -1,8 +1,7 @@
from typing import Dict, List from typing import Dict, List, Optional
from uuid import UUID from uuid import UUID
from onefuzztypes.models import ApiAccessRule from onefuzztypes.models import ApiAccessRule
from pydantic import parse_raw_as
class RuleConflictError(Exception): class RuleConflictError(Exception):
@ -41,7 +40,7 @@ class RequestAccess:
def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None: def __add_url__(self, methods: List[str], path: str, rules: Rules) -> None:
methods = list(map(lambda m: m.upper(), methods)) methods = list(map(lambda m: m.upper(), methods))
segments = path.split("/") segments = [s for s in path.split("/") if s != ""]
if len(segments) == 0: if len(segments) == 0:
return return
@ -71,15 +70,14 @@ class RequestAccess:
for method in methods: for method in methods:
current_node.rules[method] = rules current_node.rules[method] = rules
def get_matching_rules(self, method: str, path: str) -> Rules: def get_matching_rules(self, method: str, path: str) -> Optional[Rules]:
method = method.upper() method = method.upper()
segments = path.split("/") segments = [s for s in path.split("/") if s != ""]
current_node = self.root current_node = self.root
current_rule = None
if method in current_node.rules: if method in current_node.rules:
current_rule = current_node.rules[method] current_rule = current_node.rules[method]
else:
current_rule = RequestAccess.Rules()
current_segment_index = 0 current_segment_index = 0
@ -98,17 +96,13 @@ class RequestAccess:
return current_rule return current_rule
@classmethod @classmethod
def parse_rules(cls, rules_data: str) -> "RequestAccess": def build(cls, rules: Dict[str, ApiAccessRule]) -> "RequestAccess":
rules = parse_raw_as(List[ApiAccessRule], rules_data)
return cls.build(rules)
@classmethod
def build(cls, rules: List[ApiAccessRule]) -> "RequestAccess":
request_access = RequestAccess() request_access = RequestAccess()
for rule in rules: for endpoint in rules:
rule = rules[endpoint]
request_access.__add_url__( request_access.__add_url__(
rule.methods, rule.methods,
rule.endpoint, endpoint,
RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups), RequestAccess.Rules(allowed_groups_ids=rule.allowed_groups),
) )

View File

@ -0,0 +1,144 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import os
import time
import uuid
from typing import Any, List
from urllib.parse import urlparse
from uuid import UUID
from azure.cli.core import get_default_cli
from onefuzz.api import Onefuzz
from onefuzztypes.models import ApiAccessRule
def az_cli(args: List[str]) -> Any:
cli = get_default_cli()
cli.logging_cls
cli.invoke(args, out_file=open(os.devnull, "w"))
if cli.result.result:
return cli.result.result
elif cli.result.error:
raise cli.result.error
class APIRestrictionTests:
def __init__(
self, resource_group: str = None, onefuzz_config_path: str = None
) -> None:
self.onefuzz = Onefuzz(config_path=onefuzz_config_path)
self.intial_config = self.onefuzz.instance_config.get()
self.instance_name = urlparse(self.onefuzz.config().endpoint).netloc.split(".")[
0
]
if resource_group:
self.resource_group = resource_group
else:
self.resource_group = self.instance_name
def restore_config(self) -> None:
self.onefuzz.instance_config.update(self.intial_config)
def assign(self, group_id: UUID, member_id: UUID) -> None:
instance_config = self.onefuzz.instance_config.get()
if instance_config.group_membership is None:
instance_config.group_membership = {}
if member_id not in instance_config.group_membership:
instance_config.group_membership[member_id] = []
if group_id not in instance_config.group_membership[member_id]:
instance_config.group_membership[member_id].append(group_id)
self.onefuzz.instance_config.update(instance_config)
def assign_current_user(self, group_id: UUID) -> None:
onefuzz_service_appId = az_cli(
[
"ad",
"signed-in-user",
"show",
]
)
member_id = UUID(onefuzz_service_appId["objectId"])
print(f"adding user {member_id}")
self.assign(group_id, member_id)
def test_restriction_on_current_user(self) -> None:
print("Checking that the current user can get jobs")
self.onefuzz.jobs.list()
print("Creating test group")
group_id = uuid.uuid4()
print("Adding restriction to the jobs endpoint")
instance_config = self.onefuzz.instance_config.get()
if instance_config.api_access_rules is None:
instance_config.api_access_rules = {}
instance_config.api_access_rules["/api/jobs"] = ApiAccessRule(
allowed_groups=[group_id],
methods=["GET"],
)
self.onefuzz.instance_config.update(instance_config)
restart_instance(self.instance_name, self.resource_group)
time.sleep(20)
print("Checking that the current user cannot get jobs")
try:
self.onefuzz.jobs.list()
failed = False
except Exception:
failed = True
pass
if not failed:
raise Exception("Current user was able to get jobs")
print("Assigning current user to test group")
self.assign_current_user(group_id)
restart_instance(self.instance_name, self.resource_group)
time.sleep(20)
print("Checking that the current user can get jobs")
self.onefuzz.jobs.list()
def restart_instance(instance_name: str, resource_group: str) -> None:
print("Restarting instance")
az_cli(
[
"functionapp",
"restart",
"--name",
f"{instance_name}",
"--resource-group",
f"{resource_group}",
]
)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", default=None)
parser.add_argument("--resource_group", default=None)
args = parser.parse_args()
tester = APIRestrictionTests(args.resource_group, args.config_path)
try:
print("test current user restriction")
tester.test_restriction_on_current_user()
finally:
tester.restore_config()
pass
if __name__ == "__main__":
main()

View File

@ -0,0 +1,4 @@
../../cli
../../pytypes
azure-cli-core==2.27.2
azure-cli==2.27.2

View File

@ -0,0 +1,38 @@
import unittest
import uuid
from __app__.onefuzzlib.azure.group_membership import (
GroupMembershipChecker,
StaticGroupMembership,
)
class TestRequestAccess(unittest.TestCase):
def test_empty(self) -> None:
group_id = uuid.uuid4()
user_id = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership({})
self.assertFalse(checker.is_member([group_id], user_id))
self.assertTrue(checker.is_member([user_id], user_id))
def test_matching_user_id(self) -> None:
group_id = uuid.uuid4()
user_id1 = uuid.uuid4()
user_id2 = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership(
{str(user_id1): [group_id]}
)
self.assertTrue(checker.is_member([user_id1], user_id1))
self.assertFalse(checker.is_member([user_id1], user_id2))
def test_user_in_group(self) -> None:
group_id1 = uuid.uuid4()
group_id2 = uuid.uuid4()
user_id = uuid.uuid4()
checker: GroupMembershipChecker = StaticGroupMembership(
{str(user_id): [group_id1]}
)
self.assertTrue(checker.is_member([group_id1], user_id))
self.assertFalse(checker.is_member([group_id2], user_id))

View File

@ -8,61 +8,60 @@ from __app__.onefuzzlib.request_access import RequestAccess, RuleConflictError
class TestRequestAccess(unittest.TestCase): class TestRequestAccess(unittest.TestCase):
def test_empty(self) -> None: def test_empty(self) -> None:
request_access1 = RequestAccess.build([]) request_access1 = RequestAccess.build({})
rules1 = request_access1.get_matching_rules("get", "a/b/c") rules1 = request_access1.get_matching_rules("get", "a/b/c")
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing") self.assertEqual(rules1, None, "expected nothing")
guid2 = uuid.uuid4() guid2 = uuid.uuid4()
request_access1 = RequestAccess.build( request_access1 = RequestAccess.build(
[ {
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid2], allowed_groups=[guid2],
) )
] }
) )
rules1 = request_access1.get_matching_rules("get", "") rules1 = request_access1.get_matching_rules("get", "")
self.assertEqual(len(rules1.allowed_groups_ids), 0, "expected nothing") self.assertEqual(rules1, None, "expected nothing")
def test_exact_match(self) -> None: def test_exact_match(self) -> None:
guid1 = uuid.uuid4() guid1 = uuid.uuid4()
request_access = RequestAccess.build( request_access = RequestAccess.build(
[ {
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1], allowed_groups=[guid1],
) )
] }
) )
rules1 = request_access.get_matching_rules("get", "a/b/c") rules1 = request_access.get_matching_rules("get", "a/b/c")
rules2 = request_access.get_matching_rules("get", "b/b/e") rules2 = request_access.get_matching_rules("get", "b/b/e")
assert rules1 is not None
self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups") self.assertNotEqual(len(rules1.allowed_groups_ids), 0, "empty allowed groups")
self.assertEqual(rules1.allowed_groups_ids[0], guid1) self.assertEqual(rules1.allowed_groups_ids[0], guid1)
self.assertEqual(len(rules2.allowed_groups_ids), 0, "expected nothing") self.assertEqual(rules2, None, "expected nothing")
def test_wildcard(self) -> None: def test_wildcard(self) -> None:
guid1 = uuid.uuid4() guid1 = uuid.uuid4()
request_access = RequestAccess.build( request_access = RequestAccess.build(
[ {
ApiAccessRule( "b/*/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="b/*/c",
allowed_groups=[guid1], allowed_groups=[guid1],
) )
] }
) )
rules = request_access.get_matching_rules("get", "b/b/c") rules = request_access.get_matching_rules("get", "b/b/c")
assert rules is not None
self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups") self.assertNotEqual(len(rules.allowed_groups_ids), 0, "empty allowed groups")
self.assertEqual(rules.allowed_groups_ids[0], guid1) self.assertEqual(rules.allowed_groups_ids[0], guid1)
@ -71,18 +70,16 @@ class TestRequestAccess(unittest.TestCase):
try: try:
RequestAccess.build( RequestAccess.build(
[ {
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1], allowed_groups=[guid1],
), ),
ApiAccessRule( "a/b/c/": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[], allowed_groups=[],
), ),
] }
) )
self.fail("this is expected to fail") self.fail("this is expected to fail")
@ -95,22 +92,21 @@ class TestRequestAccess(unittest.TestCase):
guid2 = uuid.uuid4() guid2 = uuid.uuid4()
request_access = RequestAccess.build( request_access = RequestAccess.build(
[ {
ApiAccessRule( "a/*/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/*/c",
allowed_groups=[guid1], allowed_groups=[guid1],
), ),
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid2], allowed_groups=[guid2],
), ),
] }
) )
rules = request_access.get_matching_rules("get", "a/b/c") rules = request_access.get_matching_rules("get", "a/b/c")
assert rules is not None
self.assertEqual( self.assertEqual(
rules.allowed_groups_ids[0], rules.allowed_groups_ids[0],
guid2, guid2,
@ -125,36 +121,36 @@ class TestRequestAccess(unittest.TestCase):
guid3 = uuid.uuid4() guid3 = uuid.uuid4()
request_access = RequestAccess.build( request_access = RequestAccess.build(
[ {
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1], allowed_groups=[guid1],
), ),
ApiAccessRule( "f/*/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="f/*/c",
allowed_groups=[guid2], allowed_groups=[guid2],
), ),
ApiAccessRule( "a/b": ApiAccessRule(
methods=["post"], methods=["post"],
endpoint="a/b",
allowed_groups=[guid3], allowed_groups=[guid3],
), ),
] }
) )
rules1 = request_access.get_matching_rules("get", "a/b/c/d") rules1 = request_access.get_matching_rules("get", "a/b/c/d")
assert rules1 is not None
self.assertEqual( self.assertEqual(
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c" rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
) )
rules2 = request_access.get_matching_rules("get", "f/b/c/d") rules2 = request_access.get_matching_rules("get", "f/b/c/d")
assert rules2 is not None
self.assertEqual( self.assertEqual(
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c" rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of f/*/c"
) )
rules3 = request_access.get_matching_rules("post", "a/b/c/d") rules3 = request_access.get_matching_rules("post", "a/b/c/d")
assert rules3 is not None
self.assertEqual( self.assertEqual(
rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b" rules3.allowed_groups_ids[0], guid3, "expected to inherit rule of post a/b"
) )
@ -165,26 +161,26 @@ class TestRequestAccess(unittest.TestCase):
guid2 = uuid.uuid4() guid2 = uuid.uuid4()
request_access = RequestAccess.build( request_access = RequestAccess.build(
[ {
ApiAccessRule( "a/b/c": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c",
allowed_groups=[guid1], allowed_groups=[guid1],
), ),
ApiAccessRule( "a/b/c/d": ApiAccessRule(
methods=["get"], methods=["get"],
endpoint="a/b/c/d",
allowed_groups=[guid2], allowed_groups=[guid2],
), ),
] }
) )
rules1 = request_access.get_matching_rules("get", "a/b/c") rules1 = request_access.get_matching_rules("get", "a/b/c")
assert rules1 is not None
self.assertEqual( self.assertEqual(
rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c" rules1.allowed_groups_ids[0], guid1, "expected to inherit rule of a/b/c"
) )
rules2 = request_access.get_matching_rules("get", "a/b/c/d") rules2 = request_access.get_matching_rules("get", "a/b/c/d")
assert rules2 is not None
self.assertEqual( self.assertEqual(
rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d" rules2.allowed_groups_ids[0], guid2, "expected to inherit rule of a/b/c/d"
) )

View File

@ -266,7 +266,7 @@
"value": "[parameters('owner')]" "value": "[parameters('owner')]"
} }
], ],
"linuxFxVersion": "Python|3.7", "linuxFxVersion": "Python|3.8",
"alwaysOn": true, "alwaysOn": true,
"defaultDocuments": [], "defaultDocuments": [],
"httpLoggingEnabled": true, "httpLoggingEnabled": true,

View File

@ -838,10 +838,15 @@ class AzureVmExtensionConfig(BaseModel):
class ApiAccessRule(BaseModel): class ApiAccessRule(BaseModel):
methods: List[str] methods: List[str]
endpoint: str
allowed_groups: List[UUID] allowed_groups: List[UUID]
Endpoint = str
# json dumps doesn't support UUID as dictionary key
PrincipalID = str
GroupId = UUID
class InstanceConfig(BaseModel): class InstanceConfig(BaseModel):
# initial set of admins can only be set during deployment. # initial set of admins can only be set during deployment.
# if admins are set, only admins can update instance configs. # if admins are set, only admins can update instance configs.
@ -857,6 +862,8 @@ class InstanceConfig(BaseModel):
) )
extensions: Optional[AzureVmExtensionConfig] extensions: Optional[AzureVmExtensionConfig]
proxy_vm_sku: str = Field(default="Standard_B2s") proxy_vm_sku: str = Field(default="Standard_B2s")
api_access_rules: Optional[Dict[Endpoint, ApiAccessRule]] = None
group_membership: Optional[Dict[PrincipalID, List[GroupId]]] = None
def update(self, config: "InstanceConfig") -> None: def update(self, config: "InstanceConfig") -> None:
for field in config.__fields__: for field in config.__fields__: