handle azure-mgmt expired auth tokens by clearing the client cache and retrying (#1099)

To reduce how frequently the IMS is hit from the service, the service caches the azure-mgmt clients between API calls. Although the management APIs should have some amount of authentication-expiration redundancy built in, not all of them do.

The failure surfaces as a `ClientAuthenticationError`, most often with a nested exception record of `ExpiredAuthenticationToken`.

This commit wraps the compute-layer functions in a decorator that catches `ClientAuthenticationError`, clears the cached Azure clients, and retries the request once.
This commit is contained in:
bmc-msft
2021-07-22 14:01:02 -04:00
committed by GitHub
parent 3289644d2b
commit ee3d0871f2
2 changed files with 67 additions and 2 deletions

View File

@ -3,10 +3,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import functools
import logging
import os
from typing import Any, List
from typing import Any, Callable, List, TypeVar, cast
from uuid import UUID
from azure.core.exceptions import ClientAuthenticationError
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.identity import DefaultAzureCredential
@ -136,3 +139,48 @@ def get_scaleset_principal_id() -> UUID:
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential())
def clear_azure_client_cache() -> None:
    """Reset the memoized Azure client getters.

    Each getter listed below is asked for a ``clear_cache`` attribute; when
    the attribute is present it is called, otherwise the getter is skipped.
    """
    # Imported at call time rather than at module scope -- presumably to
    # avoid a circular import with these sibling modules (TODO confirm).
    from .compute import get_compute_client
    from .containers import get_blob_service
    from .network_mgmt_client import get_network_client
    from .storage import get_mgmt_client

    # memoization.cache currently does not project the wrapped function's
    # types. As a workaround, CI comments out the `cached` wrapper before
    # running type validation, so `clear_cache` may legitimately be absent.
    # Looking it up dynamically works in both configurations.
    memoized_getters = (
        get_msi,
        get_identity,
        get_compute_client,
        get_blob_service,
        get_network_client,
        get_mgmt_client,
    )
    for getter in memoized_getters:
        reset = getattr(getter, "clear_cache", None)
        if reset is None:
            continue
        reset()
T = TypeVar("T", bound=Callable[..., Any])


class retry_on_auth_failure:
    """Decorator that retries a call once after an Azure auth failure.

    Used as ``@retry_on_auth_failure()``. If the wrapped callable raises
    ``ClientAuthenticationError``, a warning is logged, the cached Azure
    clients are cleared via ``clear_azure_client_cache()``, and the callable
    is invoked a second time with the same arguments. An exception raised by
    that second attempt propagates to the caller.
    """

    def __call__(self, func: T) -> T:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # type: ignore
            try:
                result = func(*args, **kwargs)
            except ClientAuthenticationError as err:
                logging.warning(
                    "clearing authentication cache after auth failure: %s", err
                )
                # Drop the memoized clients so the retry builds new ones.
                clear_azure_client_cache()
                return func(*args, **kwargs)
            else:
                return result

        # The cast keeps the decorated function's original signature visible
        # to type checkers.
        return cast(T, wrapper)

View File

@ -26,10 +26,15 @@ from onefuzztypes.models import Error
from onefuzztypes.primitives import Region
from .compute import get_compute_client
from .creds import get_base_resource_group, get_scaleset_identity_resource_path
from .creds import (
get_base_resource_group,
get_scaleset_identity_resource_path,
retry_on_auth_failure,
)
from .image import get_os
@retry_on_auth_failure()
def list_vmss(name: UUID) -> Optional[List[str]]:
resource_group = get_base_resource_group()
client = get_compute_client()
@ -47,6 +52,7 @@ def list_vmss(name: UUID) -> Optional[List[str]]:
return None
@retry_on_auth_failure()
def delete_vmss(name: UUID) -> bool:
resource_group = get_base_resource_group()
compute_client = get_compute_client()
@ -63,6 +69,7 @@ def delete_vmss(name: UUID) -> bool:
return bool(response.status() == "Succeeded")
@retry_on_auth_failure()
def get_vmss(name: UUID) -> Optional[Any]:
resource_group = get_base_resource_group()
logging.debug("getting vm: %s", name)
@ -75,6 +82,7 @@ def get_vmss(name: UUID) -> Optional[Any]:
return None
@retry_on_auth_failure()
def resize_vmss(name: UUID, capacity: int) -> None:
check_can_update(name)
@ -94,6 +102,7 @@ def resize_vmss(name: UUID, capacity: int) -> None:
)
@retry_on_auth_failure()
def get_vmss_size(name: UUID) -> Optional[int]:
vmss = get_vmss(name)
if vmss is None:
@ -101,6 +110,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:
return cast(int, vmss.sku.capacity)
@retry_on_auth_failure()
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
logging.debug("get instance IDs for scaleset: %s", name)
resource_group = get_base_resource_group()
@ -114,10 +124,12 @@ def list_instance_ids(name: UUID) -> Dict[UUID, str]:
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
except (ResourceNotFoundError, CloudError):
logging.debug("vm does not exist %s", name)
return results
@cached(ttl=60)
@retry_on_auth_failure()
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
resource_group = get_base_resource_group()
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
@ -151,6 +163,7 @@ def check_can_update(name: UUID) -> Any:
return vmss
@retry_on_auth_failure()
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name)
@ -175,6 +188,7 @@ def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
return None
@retry_on_auth_failure()
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name)
@ -201,6 +215,7 @@ def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
return None
@retry_on_auth_failure()
def update_extensions(name: UUID, extensions: List[Any]) -> None:
check_can_update(name)
@ -215,6 +230,7 @@ def update_extensions(name: UUID, extensions: List[Any]) -> None:
logging.info("VM extensions updated: %s", name)
@retry_on_auth_failure()
def create_vmss(
location: Region,
name: UUID,
@ -382,6 +398,7 @@ def create_vmss(
@cached(ttl=60)
@retry_on_auth_failure()
def list_available_skus(location: str) -> List[str]:
compute_client = get_compute_client()