support multiple corpus accounts (#334)

Add support for sharding across multiple storage accounts for blob containers used for corpus management.

Things to note:

1. Additional storage accounts must be in the same resource group, expose the `blob` endpoint, and carry the tag `storage_type` with the value `corpus`. A utility (`src/utils/add-corpus-storage-accounts`) is provided to add such storage accounts.
2. If any secondary storage accounts exist, they are used by default for containers (see the sketch after this list).
3. Storage account names are cached in memory for the lifetime of the Azure Function instance. After adding new storage accounts, the app must be restarted to pick up the new accounts.
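
As a rough illustration of these rules, here is a minimal sketch of the discovery and selection logic. The dataclass and function parameters are illustrative stand-ins for the Azure SDK objects; the actual implementation lives in the new `storage.py` added by this commit.

```python
# Minimal sketch of the rules above, using plain data instead of Azure SDK
# objects. The names corpus_account_ids/choose_account are illustrative.
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class AccountInfo:
    id: str
    blob_endpoint: Optional[str]
    tags: Dict[str, str] = field(default_factory=dict)


def corpus_account_ids(
    primary: str, config: str, candidates: List[AccountInfo]
) -> List[str]:
    # The primary (ONEFUZZ_DATA_STORAGE) account is always first; secondary
    # accounts must expose a blob endpoint and carry storage_type=corpus.
    results = [primary]
    for account in candidates:
        if account.id in (primary, config):
            continue
        if account.blob_endpoint is None:
            continue
        if account.tags.get("storage_type") != "corpus":
            continue
        results.append(account.id)
    return results


def choose_account(accounts: List[str]) -> str:
    # Prefer a random secondary account when any exist; this keeps container
    # IOPS off the primary account, which also hosts the storage queues.
    if not accounts:
        raise ValueError("no storage accounts available")
    if len(accounts) == 1:
        return accounts[0]
    return random.choice(accounts[1:])
```

In the service, the discovered account list is memoized with `@cached` (no TTL), which is why the Function app must be restarted to pick up newly added accounts.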
bmc-msft
2021-01-06 18:11:39 -05:00
committed by GitHub
parent f345bd239d
commit 3b26ffef65
29 changed files with 496 additions and 179 deletions

View File

@@ -4,81 +4,134 @@
# Licensed under the MIT License.
import datetime
import logging
import os
import urllib.parse
from enum import Enum
from typing import Any, Dict, Optional, Union, cast
from typing import Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions
from azure.storage.blob import BlobPermissions, BlockBlobService, ContainerPermissions
from memoization import cached
from onefuzztypes.primitives import Container
from .creds import get_blob_service, get_func_storage, get_fuzz_storage
from .storage import (
StorageType,
choose_account,
get_accounts,
get_storage_account_name_key,
)
class StorageType(Enum):
corpus = "corpus"
config = "config"
@cached
def get_blob_service(account_id: str) -> BlockBlobService:
logging.debug("getting blob container (account_id: %s)", account_id)
account_name, account_key = get_storage_account_name_key(account_id)
service = BlockBlobService(account_name=account_name, account_key=account_key)
return service
def get_account_id_by_type(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
account_id = get_fuzz_storage()
elif storage_type == StorageType.config:
account_id = get_func_storage()
else:
raise NotImplementedError
return account_id
def get_service_by_container(
container: Container, storage_type: StorageType
) -> Optional[BlockBlobService]:
account = get_account_by_container(container, storage_type)
if account is None:
return None
service = get_blob_service(account)
return service
@cached(ttl=5)
def get_blob_service_by_type(storage_type: StorageType) -> Any:
account_id = get_account_id_by_type(storage_type)
return get_blob_service(account_id)
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
def container_exists_on_account(container: Container, account_id: str) -> bool:
try:
get_blob_service_by_type(storage_type).get_container_properties(name)
get_blob_service(account_id).get_container_properties(container)
return True
except AzureHttpError:
return False
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
return {
x.name: x.metadata
for x in get_blob_service_by_type(storage_type).list_containers(
include_metadata=True
)
if not x.name.startswith("$")
}
def get_container_metadata(
name: str, storage_type: StorageType
) -> Optional[Dict[str, str]]:
def container_metadata(container: Container, account: str) -> Optional[Dict[str, str]]:
try:
result = get_blob_service_by_type(storage_type).get_container_metadata(name)
result = get_blob_service(account).get_container_metadata(container)
return cast(Dict[str, str], result)
except AzureHttpError:
pass
return None
def create_container(
name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
def get_account_by_container(
container: Container, storage_type: StorageType
) -> Optional[str]:
try:
get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
except AzureHttpError:
# azure storage already logs errors
accounts = get_accounts(storage_type)
# check secondary accounts first by searching in reverse.
#
# By implementation, the primary account is specified first, followed by
# any secondary accounts.
#
# Secondary accounts, if they exist, are preferred for containers and have
# increased IOP rates; this should be a slight optimization
for account in reversed(accounts):
if container_exists_on_account(container, account):
return account
return None
def container_exists(container: Container, storage_type: StorageType) -> bool:
return get_account_by_container(container, storage_type) is not None
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
containers: Dict[str, Dict[str, str]] = {}
for account_id in get_accounts(storage_type):
containers.update(
{
x.name: x.metadata
for x in get_blob_service(account_id).list_containers(
include_metadata=True
)
}
)
return containers
def get_container_metadata(
container: Container, storage_type: StorageType
) -> Optional[Dict[str, str]]:
account = get_account_by_container(container, storage_type)
if account is None:
return None
return get_container_sas_url(
name,
storage_type,
return container_metadata(container, account)
def create_container(
container: Container,
storage_type: StorageType,
metadata: Optional[Dict[str, str]],
) -> Optional[str]:
service = get_service_by_container(container, storage_type)
if service is None:
account = choose_account(storage_type)
service = get_blob_service(account)
try:
service.create_container(container, metadata=metadata)
except AzureHttpError as err:
logging.error(
(
"unable to create container. account: %s "
"container: %s metadata: %s - %s"
),
account,
container,
metadata,
err,
)
return None
return get_container_sas_url_service(
container,
service,
read=True,
add=True,
create=True,
@@ -88,17 +141,19 @@ def create_container(
)
def delete_container(name: str, storage_type: StorageType) -> bool:
try:
return bool(get_blob_service_by_type(storage_type).delete_container(name))
except AzureHttpError:
# azure storage already logs errors
return False
def delete_container(container: Container, storage_type: StorageType) -> bool:
accounts = get_accounts(storage_type)
for account in accounts:
service = get_blob_service(account)
if bool(service.delete_container(container)):
return True
return False
def get_container_sas_url(
container: str,
storage_type: StorageType,
def get_container_sas_url_service(
container: Container,
service: BlockBlobService,
*,
read: bool = False,
add: bool = False,
@@ -107,7 +162,6 @@ def get_container_sas_url(
delete: bool = False,
list: bool = False,
) -> str:
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
permission = ContainerPermissions(read, add, create, write, delete, list)
@@ -120,8 +174,35 @@ def get_container_sas_url(
return str(url)
def get_container_sas_url(
container: Container,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
write: bool = False,
delete: bool = False,
list: bool = False,
) -> str:
service = get_service_by_container(container, storage_type)
if not service:
raise Exception("unable to create container sas for missing container")
return get_container_sas_url_service(
container,
service,
read=read,
add=add,
create=create,
write=write,
delete=delete,
list=list,
)
def get_file_sas_url(
container: str,
container: Container,
name: str,
storage_type: StorageType,
*,
@@ -135,7 +216,10 @@ def get_file_sas_url(
hours: int = 0,
minutes: int = 0,
) -> str:
service = get_blob_service_by_type(storage_type)
service = get_service_by_container(container, storage_type)
if not service:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
expiry = datetime.datetime.utcnow() + datetime.timedelta(
days=days, hours=hours, minutes=minutes
)
@@ -150,18 +234,28 @@
def save_blob(
container: str, name: str, data: Union[str, bytes], storage_type: StorageType
container: Container,
name: str,
data: Union[str, bytes],
storage_type: StorageType,
) -> None:
service = get_blob_service_by_type(storage_type)
service.create_container(container)
service = get_service_by_container(container, storage_type)
if not service:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
if isinstance(data, str):
service.create_blob_from_text(container, name, data)
elif isinstance(data, bytes):
service.create_blob_from_bytes(container, name, data)
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
service = get_blob_service_by_type(storage_type)
def get_blob(
container: Container, name: str, storage_type: StorageType
) -> Optional[bytes]:
service = get_service_by_container(container, storage_type)
if not service:
return None
try:
blob = service.get_blob_to_bytes(container, name).content
return cast(bytes, blob)
@@ -169,8 +263,11 @@ def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[b
return None
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
service = get_service_by_container(container, storage_type)
if not service:
return False
try:
service.get_blob_properties(container, name)
return True
@@ -178,8 +275,11 @@ def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
return False
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
service = get_service_by_container(container, storage_type)
if not service:
return False
try:
service.delete_blob(container, name)
return True
@@ -187,7 +287,7 @@ def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
return False
def auth_download_url(container: str, filename: str) -> str:
def auth_download_url(container: Container, filename: str) -> str:
instance = os.environ["ONEFUZZ_INSTANCE"]
return "%s/api/download?%s" % (
instance,

View File

@@ -3,9 +3,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
from typing import Any, List, Optional, Tuple
from typing import Any, List
from uuid import UUID
from azure.cli.core import CLIError
@@ -13,12 +12,11 @@ from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.subscription import SubscriptionClient
from azure.storage.blob import BlockBlobService
from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container
from .monkeypatch import allow_more_workers, reduce_logging
@@ -35,34 +33,14 @@ def mgmt_client_factory(client_class: Any) -> Any:
try:
return get_client_from_cli_profile(client_class)
except CLIError:
if issubclass(client_class, SubscriptionClient):
return client_class(get_msi())
else:
return client_class(get_msi(), get_subscription())
pass
except OSError:
pass
@cached
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
db_client = mgmt_client_factory(StorageManagementClient)
if account_id is None:
account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
resource = parse_resource_id(account_id)
key = (
db_client.storage_accounts.list_keys(
resource["resource_group"], resource["name"]
)
.keys[0]
.value
)
return resource["name"], key
@cached
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
service = BlockBlobService(account_name=name, account_key=key)
return service
if issubclass(client_class, SubscriptionClient):
return client_class(get_msi())
else:
return client_class(get_msi(), get_subscription())
@cached
@@ -92,16 +70,6 @@ def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"]
# @cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
# @cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached
def get_instance_name() -> str:
return os.environ["ONEFUZZ_INSTANCE_NAME"]
@@ -114,9 +82,10 @@ def get_instance_url() -> str:
@cached
def get_instance_id() -> UUID:
from .containers import StorageType, get_blob
from .containers import get_blob
from .storage import StorageType
blob = get_blob("base-config", "instance_id", StorageType.config)
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
if blob is None:
raise Exception("missing instance_id")
return UUID(blob.decode())

View File

@@ -19,8 +19,7 @@ from azure.storage.queue import (
from memoization import cached
from pydantic import BaseModel
from .containers import StorageType, get_account_id_by_type
from .creds import get_storage_account_name_key
from .storage import StorageType, get_primary_account, get_storage_account_name_key
QueueNameType = Union[str, UUID]
@@ -29,7 +28,7 @@ DEFAULT_TTL = -1
@cached(ttl=60)
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_account_id_by_type(storage_type)
account_id = get_primary_account(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name
@@ -50,7 +49,7 @@ def get_queue_sas(
update: bool = False,
process: bool = False,
) -> str:
account_id = get_account_id_by_type(storage_type)
account_id = get_primary_account(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)

View File

@@ -0,0 +1,108 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
import random
from enum import Enum
from typing import List, Tuple
from azure.mgmt.storage import StorageManagementClient
from memoization import cached
from msrestazure.tools import parse_resource_id
from .creds import get_base_resource_group, mgmt_client_factory
class StorageType(Enum):
corpus = "corpus"
config = "config"
@cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached
def get_primary_account(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
# see #322 for discussion about typing
return get_fuzz_storage()
elif storage_type == StorageType.config:
# see #322 for discussion about typing
return get_func_storage()
raise NotImplementedError
@cached
def get_accounts(storage_type: StorageType) -> List[str]:
if storage_type == StorageType.corpus:
return corpus_accounts()
elif storage_type == StorageType.config:
return [get_func_storage()]
else:
raise NotImplementedError
@cached
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
client = mgmt_client_factory(StorageManagementClient)
resource = parse_resource_id(account_id)
key = (
client.storage_accounts.list_keys(resource["resource_group"], resource["name"])
.keys[0]
.value
)
return resource["name"], key
def choose_account(storage_type: StorageType) -> str:
accounts = get_accounts(storage_type)
if not accounts:
raise Exception(f"no storage accounts for {storage_type}")
if len(accounts) == 1:
return accounts[0]
# Use a random secondary storage account if any are available. This
# reduces IOP contention for the Storage Queues, which are only available
# on primary accounts
return random.choice(accounts[1:])
@cached
def corpus_accounts() -> List[str]:
skip = get_func_storage()
results = [get_fuzz_storage()]
client = mgmt_client_factory(StorageManagementClient)
group = get_base_resource_group()
for account in client.storage_accounts.list_by_resource_group(group):
# protection from someone adding the corpus tag to the config account
if account.id == skip:
continue
if account.id in results:
continue
if account.primary_endpoints.blob is None:
continue
if (
"storage_type" not in account.tags
or account.tags["storage_type"] != "corpus"
):
continue
results.append(account.id)
logging.info("corpus accounts: %s", corpus_accounts)
return results

View File

@@ -10,7 +10,7 @@ from typing import Optional
from azure.cosmosdb.table import TableService
from memoization import cached
from .creds import get_storage_account_name_key
from .storage import get_storage_account_name_key
@cached(ttl=60)