mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 20:08:09 +00:00
handle azure-mgmt expired auth tokens by clearing the client cache and retrying (#1099)
To reduce how frequently the service hits the IMS, the service caches the azure-mgmt clients between API calls. While the management APIs should have some amount of authentication-expiration redundancy built in, not all of them do. This shows up as `ClientAuthenticationError`, most often with a nested exception record of `ExpiredAuthenticationToken`. This change wraps each compute-layer function in a decorator that catches `ClientAuthenticationError`, clears the cached Azure clients, and retries the request once.
This commit is contained in:
@ -3,10 +3,13 @@
|
|||||||
# Copyright (c) Microsoft Corporation.
|
# Copyright (c) Microsoft Corporation.
|
||||||
# Licensed under the MIT License.
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, List
|
from typing import Any, Callable, List, TypeVar, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from azure.core.exceptions import ClientAuthenticationError
|
||||||
from azure.graphrbac import GraphRbacManagementClient
|
from azure.graphrbac import GraphRbacManagementClient
|
||||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||||
from azure.identity import DefaultAzureCredential
|
from azure.identity import DefaultAzureCredential
|
||||||
@ -136,3 +139,48 @@ def get_scaleset_principal_id() -> UUID:
|
|||||||
@cached
|
@cached
|
||||||
def get_keyvault_client(vault_url: str) -> SecretClient:
|
def get_keyvault_client(vault_url: str) -> SecretClient:
|
||||||
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
|
||||||
|
|
||||||
|
|
||||||
|
def clear_azure_client_cache() -> None:
    """Reset the memoized Azure client factories.

    Each factory listed below is asked to drop its cached result by calling
    its ``clear_cache`` attribute, when it has one.  Factories without that
    attribute are skipped silently.
    """
    # NOTE(review): these imports are function-local, presumably to avoid a
    # circular import with the sibling modules -- confirm before hoisting.
    from .compute import get_compute_client
    from .containers import get_blob_service
    from .network_mgmt_client import get_network_client
    from .storage import get_mgmt_client

    # memoization.cache currently does not project the wrapped function's
    # types.  As a workaround, CI comments out the `cached` wrapper before
    # running type validation, so `clear_cache` may be absent; look it up
    # dynamically and only call it when the wrapper is actually in place.
    cached_factories = (
        get_msi,
        get_identity,
        get_compute_client,
        get_blob_service,
        get_network_client,
        get_mgmt_client,
    )
    for factory in cached_factories:
        reset = getattr(factory, "clear_cache", None)
        if reset is None:
            continue
        reset()
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
class retry_on_auth_failure:
    """Decorator that retries a call once after an Azure auth failure.

    Usage: ``@retry_on_auth_failure()``.  When the wrapped callable raises
    ``ClientAuthenticationError``, a warning is logged, the cached Azure
    clients are cleared via ``clear_azure_client_cache`` and the callable is
    invoked a second time with the same arguments.  An exception raised by
    that second attempt is not caught here and propagates to the caller.
    """

    def __call__(self, func: T) -> T:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            try:
                outcome = func(*args, **kwargs)
            except ClientAuthenticationError as auth_error:
                logging.warning(
                    "clearing authentication cache after auth failure: %s",
                    auth_error,
                )
            else:
                # first attempt succeeded; nothing to clear or retry
                return outcome

            # Drop the memoized clients so the retry builds fresh ones.
            clear_azure_client_cache()
            return func(*args, **kwargs)

        # functools.wraps keeps the metadata; the cast keeps the static type.
        return cast(T, wrapper)
|
||||||
|
@ -26,10 +26,15 @@ from onefuzztypes.models import Error
|
|||||||
from onefuzztypes.primitives import Region
|
from onefuzztypes.primitives import Region
|
||||||
|
|
||||||
from .compute import get_compute_client
|
from .compute import get_compute_client
|
||||||
from .creds import get_base_resource_group, get_scaleset_identity_resource_path
|
from .creds import (
|
||||||
|
get_base_resource_group,
|
||||||
|
get_scaleset_identity_resource_path,
|
||||||
|
retry_on_auth_failure,
|
||||||
|
)
|
||||||
from .image import get_os
|
from .image import get_os
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def list_vmss(name: UUID) -> Optional[List[str]]:
|
def list_vmss(name: UUID) -> Optional[List[str]]:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
client = get_compute_client()
|
client = get_compute_client()
|
||||||
@ -47,6 +52,7 @@ def list_vmss(name: UUID) -> Optional[List[str]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def delete_vmss(name: UUID) -> bool:
|
def delete_vmss(name: UUID) -> bool:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
compute_client = get_compute_client()
|
compute_client = get_compute_client()
|
||||||
@ -63,6 +69,7 @@ def delete_vmss(name: UUID) -> bool:
|
|||||||
return bool(response.status() == "Succeeded")
|
return bool(response.status() == "Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def get_vmss(name: UUID) -> Optional[Any]:
|
def get_vmss(name: UUID) -> Optional[Any]:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
logging.debug("getting vm: %s", name)
|
logging.debug("getting vm: %s", name)
|
||||||
@ -75,6 +82,7 @@ def get_vmss(name: UUID) -> Optional[Any]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def resize_vmss(name: UUID, capacity: int) -> None:
|
def resize_vmss(name: UUID, capacity: int) -> None:
|
||||||
check_can_update(name)
|
check_can_update(name)
|
||||||
|
|
||||||
@ -94,6 +102,7 @@ def resize_vmss(name: UUID, capacity: int) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def get_vmss_size(name: UUID) -> Optional[int]:
|
def get_vmss_size(name: UUID) -> Optional[int]:
|
||||||
vmss = get_vmss(name)
|
vmss = get_vmss(name)
|
||||||
if vmss is None:
|
if vmss is None:
|
||||||
@ -101,6 +110,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:
|
|||||||
return cast(int, vmss.sku.capacity)
|
return cast(int, vmss.sku.capacity)
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||||
logging.debug("get instance IDs for scaleset: %s", name)
|
logging.debug("get instance IDs for scaleset: %s", name)
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
@ -114,10 +124,12 @@ def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
|||||||
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
||||||
except (ResourceNotFoundError, CloudError):
|
except (ResourceNotFoundError, CloudError):
|
||||||
logging.debug("vm does not exist %s", name)
|
logging.debug("vm does not exist %s", name)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
@cached(ttl=60)
|
||||||
|
@retry_on_auth_failure()
|
||||||
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
||||||
resource_group = get_base_resource_group()
|
resource_group = get_base_resource_group()
|
||||||
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
||||||
@ -151,6 +163,7 @@ def check_can_update(name: UUID) -> Any:
|
|||||||
return vmss
|
return vmss
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||||
check_can_update(name)
|
check_can_update(name)
|
||||||
|
|
||||||
@ -175,6 +188,7 @@ def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||||
check_can_update(name)
|
check_can_update(name)
|
||||||
|
|
||||||
@ -201,6 +215,7 @@ def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
||||||
check_can_update(name)
|
check_can_update(name)
|
||||||
|
|
||||||
@ -215,6 +230,7 @@ def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
|||||||
logging.info("VM extensions updated: %s", name)
|
logging.info("VM extensions updated: %s", name)
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_auth_failure()
|
||||||
def create_vmss(
|
def create_vmss(
|
||||||
location: Region,
|
location: Region,
|
||||||
name: UUID,
|
name: UUID,
|
||||||
@ -382,6 +398,7 @@ def create_vmss(
|
|||||||
|
|
||||||
|
|
||||||
@cached(ttl=60)
|
@cached(ttl=60)
|
||||||
|
@retry_on_auth_failure()
|
||||||
def list_available_skus(location: str) -> List[str]:
|
def list_available_skus(location: str) -> List[str]:
|
||||||
compute_client = get_compute_client()
|
compute_client = get_compute_client()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user