mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 20:08:09 +00:00
handle azure-mgmt expired auth tokens by clearing the client cache and retrying (#1099)
In order to reduce how frequently the IMS is hit from the service, the service caches the azure-mgmt clients between API calls. While the management APIs should have some amount of authentication expiration redundancy built in, not all of them do. This is seen with `ClientAuthenticationError`, most often with the nested exception record of `ExpiredAuthenticationToken`. This change wraps all of the compute-layer functionality in a decorator that catches `ClientAuthenticationError`, clears the cached Azure clients, and retries the request once.
This commit is contained in:
@ -3,10 +3,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, List, TypeVar, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ClientAuthenticationError
|
||||
from azure.graphrbac import GraphRbacManagementClient
|
||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||
from azure.identity import DefaultAzureCredential
|
||||
@ -136,3 +139,48 @@ def get_scaleset_principal_id() -> UUID:
|
||||
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
    """Return a ``SecretClient`` for the Key Vault at ``vault_url``.

    The client authenticates with a freshly constructed
    ``DefaultAzureCredential``. The function is wrapped by ``cached``.
    """
    credential = DefaultAzureCredential()
    return SecretClient(vault_url=vault_url, credential=credential)
|
||||
|
||||
|
||||
def clear_azure_client_cache() -> None:
    """Clear the memoization of the Azure client factory functions.

    Every factory listed below that exposes a ``clear_cache`` attribute has
    it called; factories without that attribute are left alone.
    """
    # NOTE(review): these imports are function-local, presumably to avoid
    # import cycles between this module and its siblings -- confirm.
    from .compute import get_compute_client
    from .containers import get_blob_service
    from .network_mgmt_client import get_network_client
    from .storage import get_mgmt_client

    memoized_factories = (
        get_msi,
        get_identity,
        get_compute_client,
        get_blob_service,
        get_network_client,
        get_mgmt_client,
    )

    # currently memoization.cache does not project the wrapped function's
    # types.  As a workaround, CI comments out the `cached` wrapper, then
    # runs the type validation.  Looking the attribute up dynamically lets
    # us call the wrapper's clear_cache only when the wrapper is present.
    for factory in memoized_factories:
        reset = getattr(factory, "clear_cache", None)
        if reset is None:
            continue
        reset()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])


class retry_on_auth_failure:
    """Decorator factory: retry a call once after an Azure auth failure.

    Use as ``@retry_on_auth_failure()``.  When the wrapped callable raises
    ``ClientAuthenticationError``, a warning is logged, the cached Azure
    clients are cleared via ``clear_azure_client_cache()``, and the callable
    is invoked a second time with the same arguments.  Whatever the second
    attempt returns or raises is passed straight through to the caller.
    """

    def __call__(self, func: T) -> T:
        """Wrap ``func`` with the clear-cache-and-retry behaviour."""

        @functools.wraps(func)
        def retrying(*args, **kwargs):  # type: ignore
            try:
                outcome = func(*args, **kwargs)
            except ClientAuthenticationError as auth_error:
                logging.warning(
                    "clearing authentication cache after auth failure: %s",
                    auth_error,
                )

                # The retry stays inside the handler so that a second
                # failure is still chained to the first one.
                clear_azure_client_cache()
                return func(*args, **kwargs)
            else:
                return outcome

        return cast(T, retrying)
|
||||
|
@ -26,10 +26,15 @@ from onefuzztypes.models import Error
|
||||
from onefuzztypes.primitives import Region
|
||||
|
||||
from .compute import get_compute_client
|
||||
from .creds import get_base_resource_group, get_scaleset_identity_resource_path
|
||||
from .creds import (
|
||||
get_base_resource_group,
|
||||
get_scaleset_identity_resource_path,
|
||||
retry_on_auth_failure,
|
||||
)
|
||||
from .image import get_os
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def list_vmss(name: UUID) -> Optional[List[str]]:
|
||||
resource_group = get_base_resource_group()
|
||||
client = get_compute_client()
|
||||
@ -47,6 +52,7 @@ def list_vmss(name: UUID) -> Optional[List[str]]:
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def delete_vmss(name: UUID) -> bool:
|
||||
resource_group = get_base_resource_group()
|
||||
compute_client = get_compute_client()
|
||||
@ -63,6 +69,7 @@ def delete_vmss(name: UUID) -> bool:
|
||||
return bool(response.status() == "Succeeded")
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def get_vmss(name: UUID) -> Optional[Any]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.debug("getting vm: %s", name)
|
||||
@ -75,6 +82,7 @@ def get_vmss(name: UUID) -> Optional[Any]:
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def resize_vmss(name: UUID, capacity: int) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
@ -94,6 +102,7 @@ def resize_vmss(name: UUID, capacity: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def get_vmss_size(name: UUID) -> Optional[int]:
|
||||
vmss = get_vmss(name)
|
||||
if vmss is None:
|
||||
@ -101,6 +110,7 @@ def get_vmss_size(name: UUID) -> Optional[int]:
|
||||
return cast(int, vmss.sku.capacity)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||
logging.debug("get instance IDs for scaleset: %s", name)
|
||||
resource_group = get_base_resource_group()
|
||||
@ -114,10 +124,12 @@ def list_instance_ids(name: UUID) -> Dict[UUID, str]:
|
||||
results[UUID(instance.vm_id)] = cast(str, instance.instance_id)
|
||||
except (ResourceNotFoundError, CloudError):
|
||||
logging.debug("vm does not exist %s", name)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
@retry_on_auth_failure()
|
||||
def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]:
|
||||
resource_group = get_base_resource_group()
|
||||
logging.info("get instance ID for scaleset node: %s:%s", name, vm_id)
|
||||
@ -151,6 +163,7 @@ def check_can_update(name: UUID) -> Any:
|
||||
return vmss
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
@ -175,6 +188,7 @@ def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
check_can_update(name)
|
||||
|
||||
@ -201,6 +215,7 @@ def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
|
||||
return None
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
||||
check_can_update(name)
|
||||
|
||||
@ -215,6 +230,7 @@ def update_extensions(name: UUID, extensions: List[Any]) -> None:
|
||||
logging.info("VM extensions updated: %s", name)
|
||||
|
||||
|
||||
@retry_on_auth_failure()
|
||||
def create_vmss(
|
||||
location: Region,
|
||||
name: UUID,
|
||||
@ -382,6 +398,7 @@ def create_vmss(
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
@retry_on_auth_failure()
|
||||
def list_available_skus(location: str) -> List[str]:
|
||||
compute_client = get_compute_client()
|
||||
|
||||
|
Reference in New Issue
Block a user