diff --git a/src/api-service/__app__/onefuzzlib/azure/creds.py b/src/api-service/__app__/onefuzzlib/azure/creds.py index e3669f131..c4bbc92c0 100644 --- a/src/api-service/__app__/onefuzzlib/azure/creds.py +++ b/src/api-service/__app__/onefuzzlib/azure/creds.py @@ -3,10 +3,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import functools +import logging import os -from typing import Any, List +from typing import Any, Callable, List, TypeVar, cast from uuid import UUID +from azure.core.exceptions import ClientAuthenticationError from azure.graphrbac import GraphRbacManagementClient from azure.graphrbac.models import CheckGroupMembershipParameters from azure.identity import DefaultAzureCredential @@ -136,3 +139,48 @@ def get_scaleset_principal_id() -> UUID: @cached def get_keyvault_client(vault_url: str) -> SecretClient: return SecretClient(vault_url=vault_url, credential=DefaultAzureCredential()) + + +def clear_azure_client_cache() -> None: + # clears the memoization of the Azure clients. + + from .compute import get_compute_client + from .containers import get_blob_service + from .network_mgmt_client import get_network_client + from .storage import get_mgmt_client + + # currently memoization.cache does not project the wrapped function's types. + # As a workaround, CI comments out the `cached` wrapper, then runs the type + # validation. This enables calling the wrapper's clear_cache if it's not + # disabled. + for func in [ + get_msi, + get_identity, + get_compute_client, + get_blob_service, + get_network_client, + get_mgmt_client, + ]: + clear_func = getattr(func, "clear_cache", None) + if clear_func is not None: + clear_func() + + +T = TypeVar("T", bound=Callable[..., Any]) + + +class retry_on_auth_failure: + def __call__(self, func: T) -> T: + @functools.wraps(func) + def decorated(*args, **kwargs): # type: ignore + try: + return func(*args, **kwargs) + except ClientAuthenticationError as err: + logging.warning( + "clearing authentication cache after auth failure: %s", err + ) + + clear_azure_client_cache() + return func(*args, **kwargs) + + return cast(T, decorated) diff --git a/src/api-service/__app__/onefuzzlib/azure/vmss.py b/src/api-service/__app__/onefuzzlib/azure/vmss.py index 6269fab03..14cd0f359 100644 --- a/src/api-service/__app__/onefuzzlib/azure/vmss.py +++ b/src/api-service/__app__/onefuzzlib/azure/vmss.py @@ -26,10 +26,15 @@ from onefuzztypes.models import Error from onefuzztypes.primitives import Region from .compute import get_compute_client -from .creds import get_base_resource_group, get_scaleset_identity_resource_path +from .creds import ( + get_base_resource_group, + get_scaleset_identity_resource_path, + retry_on_auth_failure, +) from .image import get_os +@retry_on_auth_failure() def list_vmss(name: UUID) -> Optional[List[str]]: resource_group = get_base_resource_group() client = get_compute_client() @@ -47,6 +52,7 @@ def list_vmss(name: UUID) -> Optional[List[str]]: return None +@retry_on_auth_failure() def delete_vmss(name: UUID) -> bool: resource_group = get_base_resource_group() compute_client = get_compute_client() @@ -63,6 +69,7 @@ def delete_vmss(name: UUID) -> bool: return bool(response.status() == "Succeeded") +@retry_on_auth_failure() def get_vmss(name: UUID) -> Optional[Any]: resource_group = get_base_resource_group() logging.debug("getting vm: %s", name) @@ -75,6 +82,7 @@ def get_vmss(name: UUID) -> Optional[Any]: return None +@retry_on_auth_failure() def resize_vmss(name: UUID, capacity: int) -> None: check_can_update(name) @@ -94,6 +102,7 @@ def resize_vmss(name: UUID, capacity: int) -> None: ) +@retry_on_auth_failure() def get_vmss_size(name: UUID) -> Optional[int]: vmss = get_vmss(name) if vmss is None: @@ -101,6 +110,7 @@ def get_vmss_size(name: UUID) -> Optional[int]: return cast(int, vmss.sku.capacity) +@retry_on_auth_failure() def list_instance_ids(name: UUID) -> Dict[UUID, str]: logging.debug("get instance IDs for scaleset: %s", name) resource_group = get_base_resource_group() @@ -114,10 +124,12 @@ def list_instance_ids(name: UUID) -> Dict[UUID, str]: results[UUID(instance.vm_id)] = cast(str, instance.instance_id) except (ResourceNotFoundError, CloudError): logging.debug("vm does not exist %s", name) + return results @cached(ttl=60) +@retry_on_auth_failure() def get_instance_id(name: UUID, vm_id: UUID) -> Union[str, Error]: resource_group = get_base_resource_group() logging.info("get instance ID for scaleset node: %s:%s", name, vm_id) @@ -151,6 +163,7 @@ def check_can_update(name: UUID) -> Any: return vmss +@retry_on_auth_failure() def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]: check_can_update(name) @@ -175,6 +188,7 @@ def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]: return None +@retry_on_auth_failure() def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]: check_can_update(name) @@ -201,6 +215,7 @@ def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]: return None +@retry_on_auth_failure() def update_extensions(name: UUID, extensions: List[Any]) -> None: check_can_update(name) @@ -215,6 +230,7 @@ def update_extensions(name: UUID, extensions: List[Any]) -> None: logging.info("VM extensions updated: %s", name) +@retry_on_auth_failure() def create_vmss( location: Region, name: UUID, @@ -382,6 +398,7 @@ def create_vmss( @cached(ttl=60) +@retry_on_auth_failure() def list_available_skus(location: str) -> List[str]: compute_client = get_compute_client()