mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 11:58:09 +00:00
handle azure-mgmt expired auth tokens by clearing the client cache and retrying (#1099)
To reduce how often the service hits the IMS, the service caches the azure-mgmt clients between API calls. While the management APIs should have some amount of authentication-expiration redundancy built in, not all of them do. This shows up as `ClientAuthenticationError`, most often with a nested exception record of `ExpiredAuthenticationToken`. This commit wraps all of the compute-layer functionality in a decorator that catches `ClientAuthenticationError`, clears the cached clients, and retries the request once.
This commit is contained in:
@ -3,10 +3,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
from typing import Any, Callable, List, TypeVar, cast
|
||||
from uuid import UUID
|
||||
|
||||
from azure.core.exceptions import ClientAuthenticationError
|
||||
from azure.graphrbac import GraphRbacManagementClient
|
||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||
from azure.identity import DefaultAzureCredential
|
||||
@ -136,3 +139,48 @@ def get_scaleset_principal_id() -> UUID:
|
||||
@cached
def get_keyvault_client(vault_url: str) -> SecretClient:
    """Build a ``SecretClient`` for the Key Vault at ``vault_url``.

    The client authenticates with a freshly constructed
    ``DefaultAzureCredential``.

    NOTE(review): the ``cached`` wrapper presumably memoizes the client per
    ``vault_url`` -- confirm against the ``cached`` import.
    """
    credential = DefaultAzureCredential()
    return SecretClient(vault_url=vault_url, credential=credential)
|
||||
|
||||
|
||||
def clear_azure_client_cache() -> None:
    """Clear the memoization of the Azure clients.

    Each listed client factory has its cache emptied (when it has one), so
    the next call to the factory builds a new client instead of reusing one
    that may hold an expired authentication token.
    """
    # Imported inside the function, as in the original code.
    # NOTE(review): presumably this avoids circular imports -- confirm.
    from .compute import get_compute_client
    from .containers import get_blob_service
    from .network_mgmt_client import get_network_client
    from .storage import get_mgmt_client

    # currently memoization.cache does not project the wrapped function's types.
    # As a workaround, CI comments out the `cached` wrapper, then runs the type
    # validation. Looking the clearing method up dynamically keeps this code
    # valid whether or not the wrapper is disabled.
    memoized_factories = (
        get_msi,
        get_identity,
        get_compute_client,
        get_blob_service,
        get_network_client,
        get_mgmt_client,
    )

    for func in memoized_factories:
        # The memoization wrapper (like functools.lru_cache) exposes the
        # method as `cache_clear`; looking up only `clear_cache` never finds
        # it, so nothing would be cleared. `clear_cache` is kept as a
        # fallback for any wrapper that uses that spelling.
        for attr_name in ("cache_clear", "clear_cache"):
            clear_func = getattr(func, attr_name, None)
            if callable(clear_func):
                clear_func()
                break
|
||||
|
||||
|
||||
# Any callable; lets the decorator below return the same type it was given.
T = TypeVar("T", bound=Callable[..., Any])


class retry_on_auth_failure:
    """Decorator factory that retries a call once after an auth failure.

    Used as ``@retry_on_auth_failure()``. If the wrapped function raises
    ``ClientAuthenticationError``, a warning is logged, the cached Azure
    clients are cleared, and the function is called a second time with the
    same arguments. An exception raised by the second call propagates to the
    caller; any other exception type from the first call is not intercepted.
    """

    def __call__(self, func: T) -> T:
        def attempt_with_retry(*args, **kwargs):  # type: ignore
            try:
                outcome = func(*args, **kwargs)
            except ClientAuthenticationError as auth_error:
                logging.warning(
                    "clearing authentication cache after auth failure: %s",
                    auth_error,
                )

                # Drop the memoized clients before the second attempt so it
                # does not reuse the clients that just failed.
                clear_azure_client_cache()
                outcome = func(*args, **kwargs)
            return outcome

        # functools.wraps copies func's name, docstring and other metadata.
        return cast(T, functools.wraps(func)(attempt_with_retry))
|
||||
|
Reference in New Issue
Block a user