Storing the user assigned managed identity in the scaleset table (#255)

This commit is contained in:
Cheick Keita
2020-11-05 15:36:59 -08:00
committed by GitHub
parent b5578381ce
commit bbee84ab1f
3 changed files with 41 additions and 11 deletions

View File

@ -55,14 +55,15 @@ def try_get_token_auth_header(request: func.HttpRequest) -> Union[Error, TokenDa
@cached(ttl=60) @cached(ttl=60)
def is_authorized(token_data: TokenData) -> bool: def is_authorized(token_data: TokenData) -> bool:
# verify object_id against the user assigned managed identity
if get_scaleset_principal_id() == token_data.object_id:
return True
# backward compatibility case for scalesets deployed before the migration # backward compatibility case for scalesets deployed before the migration
# to user assigned managed id # to user assigned managed id
scalesets = Scaleset.get_by_object_id(token_data.object_id) scalesets = Scaleset.get_by_object_id(token_data.object_id)
return len(scalesets) > 0 if len(scalesets) > 0:
return True
# verify object_id against the user assigned managed identity
principal_id: UUID = get_scaleset_principal_id()
return principal_id == token_data.object_id
def verify_token( def verify_token(

View File

@ -151,7 +151,7 @@ def get_scaleset_identity_resource_path() -> str:
@cached @cached
def get_scaleset_principal_id() -> UUID: def get_scaleset_principal_id() -> UUID:
api_version = "2018-11-30" # matches the apiversion in the deplyoment template api_version = "2018-11-30" # matches the apiversion in the deployment template
client = mgmt_client_factory(ResourceManagementClient) client = mgmt_client_factory(ResourceManagementClient)
uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version) uid = client.resources.get_by_id(get_scaleset_identity_resource_path(), api_version)
return UUID(uid.properties["principalId"]) return UUID(uid.properties["principalId"])

View File

@ -5,7 +5,7 @@
import datetime import datetime
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from onefuzztypes.enums import ( from onefuzztypes.enums import (
@ -712,14 +712,43 @@ class Scaleset(BASE_SCALESET, ORMMixin):
logging.info("creating scaleset: %s", self.scaleset_id) logging.info("creating scaleset: %s", self.scaleset_id)
elif vmss.provisioning_state == "Creating": elif vmss.provisioning_state == "Creating":
logging.info("Waiting on scaleset creation: %s", self.scaleset_id) logging.info("Waiting on scaleset creation: %s", self.scaleset_id)
if vmss.identity and vmss.identity.principal_id: self.try_set_identity(vmss)
self.client_object_id = vmss.identity.principal_id
else: else:
logging.info("scaleset running: %s", self.scaleset_id) logging.info("scaleset running: %s", self.scaleset_id)
error = self.try_set_identity(vmss)
if error:
self.state = ScalesetState.creation_failed
self.error = error
else:
self.state = ScalesetState.running self.state = ScalesetState.running
self.client_object_id = vmss.identity.principal_id
self.save() self.save()
def try_set_identity(self, vmss: Any) -> Optional[Error]:
def get_error() -> Error:
return Error(
code=ErrorCode.VM_CREATE_FAILED,
errors=[
"The scaleset is expected to have exactly 1 user assigned identity"
],
)
if self.client_object_id:
return None
if (
vmss.identity
and vmss.identity.user_assigned_identities
and (len(vmss.identity.user_assigned_identities) != 1)
):
return get_error()
user_assinged_identities = list(vmss.identity.user_assigned_identities.values())
if user_assinged_identities[0].principal_id:
self.client_object_id = user_assinged_identities[0].principal_id
return None
else:
return get_error()
# result = 'did I modify the scaleset in azure' # result = 'did I modify the scaleset in azure'
def cleanup_nodes(self) -> bool: def cleanup_nodes(self) -> bool:
if self.state == ScalesetState.halt: if self.state == ScalesetState.halt: