mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-16 20:08:09 +00:00
support multiple corpus accounts (#334)
Add support for sharding across multiple storage accounts for blob containers used for corpus management. Things to note: 1. Additional storage accounts must be in the same resource group, support the "blob" endpoint, and have the tag `storage_type` with the value `corpus`. A utility is provided (`src/utils/add-corpus-storage-accounts`), which adds storage accounts. 2. If any secondary storage accounts exist, they are used by default for containers. 3. Storage account names are cached in memory by the Azure Function instance forever. Upon adding new storage accounts, the app needs to be restarted to pick up the new accounts.
This commit is contained in:
@ -4,81 +4,134 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import urllib.parse
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union, cast
|
||||
from typing import Dict, Optional, Union, cast
|
||||
|
||||
from azure.common import AzureHttpError, AzureMissingResourceHttpError
|
||||
from azure.storage.blob import BlobPermissions, ContainerPermissions
|
||||
from azure.storage.blob import BlobPermissions, BlockBlobService, ContainerPermissions
|
||||
from memoization import cached
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from .creds import get_blob_service, get_func_storage, get_fuzz_storage
|
||||
from .storage import (
|
||||
StorageType,
|
||||
choose_account,
|
||||
get_accounts,
|
||||
get_storage_account_name_key,
|
||||
)
|
||||
|
||||
|
||||
class StorageType(Enum):
|
||||
corpus = "corpus"
|
||||
config = "config"
|
||||
@cached
def get_blob_service(account_id: str) -> BlockBlobService:
    """Return a ``BlockBlobService`` client for the given storage account.

    The account name and key are looked up with
    ``get_storage_account_name_key``. The result is memoized by ``@cached``
    (no TTL is given), keyed on ``account_id``.
    """
    logging.debug("getting blob container (account_id: %s)", account_id)
    name, key = get_storage_account_name_key(account_id)
    return BlockBlobService(account_name=name, account_key=key)
|
||||
|
||||
|
||||
def get_account_id_by_type(storage_type: StorageType) -> str:
|
||||
if storage_type == StorageType.corpus:
|
||||
account_id = get_fuzz_storage()
|
||||
elif storage_type == StorageType.config:
|
||||
account_id = get_func_storage()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return account_id
|
||||
def get_service_by_container(
    container: Container, storage_type: StorageType
) -> Optional[BlockBlobService]:
    """Return the blob service of the account that holds ``container``.

    Returns ``None`` when ``get_account_by_container`` cannot find the
    container on any account of ``storage_type``.
    """
    owner = get_account_by_container(container, storage_type)
    return None if owner is None else get_blob_service(owner)
|
||||
|
||||
|
||||
@cached(ttl=5)
|
||||
def get_blob_service_by_type(storage_type: StorageType) -> Any:
|
||||
account_id = get_account_id_by_type(storage_type)
|
||||
return get_blob_service(account_id)
|
||||
|
||||
|
||||
@cached(ttl=5)
|
||||
def container_exists(name: str, storage_type: StorageType) -> bool:
|
||||
def container_exists_on_account(container: Container, account_id: str) -> bool:
|
||||
try:
|
||||
get_blob_service_by_type(storage_type).get_container_properties(name)
|
||||
get_blob_service(account_id).get_container_properties(container)
|
||||
return True
|
||||
except AzureHttpError:
|
||||
return False
|
||||
|
||||
|
||||
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
|
||||
return {
|
||||
x.name: x.metadata
|
||||
for x in get_blob_service_by_type(storage_type).list_containers(
|
||||
include_metadata=True
|
||||
)
|
||||
if not x.name.startswith("$")
|
||||
}
|
||||
|
||||
|
||||
def get_container_metadata(
|
||||
name: str, storage_type: StorageType
|
||||
) -> Optional[Dict[str, str]]:
|
||||
def container_metadata(container: Container, account: str) -> Optional[Dict[str, str]]:
|
||||
try:
|
||||
result = get_blob_service_by_type(storage_type).get_container_metadata(name)
|
||||
result = get_blob_service(account).get_container_metadata(container)
|
||||
return cast(Dict[str, str], result)
|
||||
except AzureHttpError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def create_container(
|
||||
name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
|
||||
def get_account_by_container(
|
||||
container: Container, storage_type: StorageType
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
|
||||
except AzureHttpError:
|
||||
# azure storage already logs errors
|
||||
accounts = get_accounts(storage_type)
|
||||
|
||||
# check secondary accounts first by searching in reverse.
|
||||
#
|
||||
# By implementation, the primary account is specified first, followed by
|
||||
# any secondary accounts.
|
||||
#
|
||||
# Secondary accounts, if they exist, are preferred for containers and have
|
||||
# increased IOP rates, this should be a slight optimization
|
||||
for account in reversed(accounts):
|
||||
if container_exists_on_account(container, account):
|
||||
return account
|
||||
return None
|
||||
|
||||
|
||||
def container_exists(container: Container, storage_type: StorageType) -> bool:
    """Report whether ``container`` exists on any account of ``storage_type``."""
    owner = get_account_by_container(container, storage_type)
    return owner is not None
|
||||
|
||||
|
||||
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
    """Map container name to metadata across all accounts of ``storage_type``.

    Accounts are visited in the order returned by ``get_accounts``; if the
    same container name shows up on more than one account, the metadata
    from the later account overwrites the earlier one.
    """
    # NOTE(review): every listed container is returned; no filtering of
    # names (e.g. `$`-prefixed ones) happens here -- confirm that is intended.
    found: Dict[str, Dict[str, str]] = {}

    for account_id in get_accounts(storage_type):
        service = get_blob_service(account_id)
        for entry in service.list_containers(include_metadata=True):
            found[entry.name] = entry.metadata

    return found
|
||||
|
||||
|
||||
def get_container_metadata(
|
||||
container: Container, storage_type: StorageType
|
||||
) -> Optional[Dict[str, str]]:
|
||||
account = get_account_by_container(container, storage_type)
|
||||
if account is None:
|
||||
return None
|
||||
|
||||
return get_container_sas_url(
|
||||
name,
|
||||
storage_type,
|
||||
return container_metadata(container, account)
|
||||
|
||||
|
||||
def create_container(
|
||||
container: Container,
|
||||
storage_type: StorageType,
|
||||
metadata: Optional[Dict[str, str]],
|
||||
) -> Optional[str]:
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if service is None:
|
||||
account = choose_account(storage_type)
|
||||
service = get_blob_service(account)
|
||||
try:
|
||||
service.create_container(container, metadata=metadata)
|
||||
except AzureHttpError as err:
|
||||
logging.error(
|
||||
(
|
||||
"unable to create container. account: %s "
|
||||
"container: %s metadata: %s - %s"
|
||||
),
|
||||
account,
|
||||
container,
|
||||
metadata,
|
||||
err,
|
||||
)
|
||||
return None
|
||||
|
||||
return get_container_sas_url_service(
|
||||
container,
|
||||
service,
|
||||
read=True,
|
||||
add=True,
|
||||
create=True,
|
||||
@ -88,17 +141,19 @@ def create_container(
|
||||
)
|
||||
|
||||
|
||||
def delete_container(name: str, storage_type: StorageType) -> bool:
|
||||
try:
|
||||
return bool(get_blob_service_by_type(storage_type).delete_container(name))
|
||||
except AzureHttpError:
|
||||
# azure storage already logs errors
|
||||
return False
|
||||
def delete_container(container: Container, storage_type: StorageType) -> bool:
    """Delete ``container`` from the first account that reports success.

    Accounts are tried in ``get_accounts`` order and the search stops at
    the first truthy result of ``delete_container``. Returns ``False`` when
    no account reported a successful delete.
    """
    return any(
        bool(get_blob_service(account).delete_container(container))
        for account in get_accounts(storage_type)
    )
|
||||
|
||||
|
||||
def get_container_sas_url(
|
||||
container: str,
|
||||
storage_type: StorageType,
|
||||
def get_container_sas_url_service(
|
||||
container: Container,
|
||||
service: BlockBlobService,
|
||||
*,
|
||||
read: bool = False,
|
||||
add: bool = False,
|
||||
@ -107,7 +162,6 @@ def get_container_sas_url(
|
||||
delete: bool = False,
|
||||
list: bool = False,
|
||||
) -> str:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
permission = ContainerPermissions(read, add, create, write, delete, list)
|
||||
|
||||
@ -120,8 +174,35 @@ def get_container_sas_url(
|
||||
return str(url)
|
||||
|
||||
|
||||
def get_container_sas_url(
    container: Container,
    storage_type: StorageType,
    *,
    read: bool = False,
    add: bool = False,
    create: bool = False,
    write: bool = False,
    delete: bool = False,
    list: bool = False,
) -> str:
    """Build a SAS URL for an existing container.

    The account holding ``container`` is located first, then the request
    is forwarded to ``get_container_sas_url_service`` with the permission
    flags unchanged.

    Raises:
        Exception: if the container is not found on any account of
            ``storage_type``.
    """
    service = get_service_by_container(container, storage_type)
    if not service:
        raise Exception("unable to create container sas for missing container")

    # Forward the permission flags by keyword, exactly as received.
    permissions = {
        "read": read,
        "add": add,
        "create": create,
        "write": write,
        "delete": delete,
        "list": list,
    }
    return get_container_sas_url_service(container, service, **permissions)
|
||||
|
||||
|
||||
def get_file_sas_url(
|
||||
container: str,
|
||||
container: Container,
|
||||
name: str,
|
||||
storage_type: StorageType,
|
||||
*,
|
||||
@ -135,7 +216,10 @@ def get_file_sas_url(
|
||||
hours: int = 0,
|
||||
minutes: int = 0,
|
||||
) -> str:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if not service:
|
||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
||||
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(
|
||||
days=days, hours=hours, minutes=minutes
|
||||
)
|
||||
@ -150,18 +234,28 @@ def get_file_sas_url(
|
||||
|
||||
|
||||
def save_blob(
|
||||
container: str, name: str, data: Union[str, bytes], storage_type: StorageType
|
||||
container: Container,
|
||||
name: str,
|
||||
data: Union[str, bytes],
|
||||
storage_type: StorageType,
|
||||
) -> None:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
service.create_container(container)
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if not service:
|
||||
raise Exception("unable to find container: %s - %s" % (container, storage_type))
|
||||
|
||||
if isinstance(data, str):
|
||||
service.create_blob_from_text(container, name, data)
|
||||
elif isinstance(data, bytes):
|
||||
service.create_blob_from_bytes(container, name, data)
|
||||
|
||||
|
||||
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
def get_blob(
|
||||
container: Container, name: str, storage_type: StorageType
|
||||
) -> Optional[bytes]:
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if not service:
|
||||
return None
|
||||
|
||||
try:
|
||||
blob = service.get_blob_to_bytes(container, name).content
|
||||
return cast(bytes, blob)
|
||||
@ -169,8 +263,11 @@ def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[b
|
||||
return None
|
||||
|
||||
|
||||
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if not service:
|
||||
return False
|
||||
|
||||
try:
|
||||
service.get_blob_properties(container, name)
|
||||
return True
|
||||
@ -178,8 +275,11 @@ def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_blob_service_by_type(storage_type)
|
||||
def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
|
||||
service = get_service_by_container(container, storage_type)
|
||||
if not service:
|
||||
return False
|
||||
|
||||
try:
|
||||
service.delete_blob(container, name)
|
||||
return True
|
||||
@ -187,7 +287,7 @@ def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def auth_download_url(container: str, filename: str) -> str:
|
||||
def auth_download_url(container: Container, filename: str) -> str:
|
||||
instance = os.environ["ONEFUZZ_INSTANCE"]
|
||||
return "%s/api/download?%s" % (
|
||||
instance,
|
||||
|
@ -3,9 +3,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from azure.cli.core import CLIError
|
||||
@ -13,12 +12,11 @@ from azure.common.client_factory import get_client_from_cli_profile
|
||||
from azure.graphrbac import GraphRbacManagementClient
|
||||
from azure.graphrbac.models import CheckGroupMembershipParameters
|
||||
from azure.mgmt.resource import ResourceManagementClient
|
||||
from azure.mgmt.storage import StorageManagementClient
|
||||
from azure.mgmt.subscription import SubscriptionClient
|
||||
from azure.storage.blob import BlockBlobService
|
||||
from memoization import cached
|
||||
from msrestazure.azure_active_directory import MSIAuthentication
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.primitives import Container
|
||||
|
||||
from .monkeypatch import allow_more_workers, reduce_logging
|
||||
|
||||
@ -35,34 +33,14 @@ def mgmt_client_factory(client_class: Any) -> Any:
|
||||
try:
|
||||
return get_client_from_cli_profile(client_class)
|
||||
except CLIError:
|
||||
if issubclass(client_class, SubscriptionClient):
|
||||
return client_class(get_msi())
|
||||
else:
|
||||
return client_class(get_msi(), get_subscription())
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@cached
|
||||
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]:
|
||||
db_client = mgmt_client_factory(StorageManagementClient)
|
||||
if account_id is None:
|
||||
account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
|
||||
resource = parse_resource_id(account_id)
|
||||
key = (
|
||||
db_client.storage_accounts.list_keys(
|
||||
resource["resource_group"], resource["name"]
|
||||
)
|
||||
.keys[0]
|
||||
.value
|
||||
)
|
||||
return resource["name"], key
|
||||
|
||||
|
||||
@cached
|
||||
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
|
||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
service = BlockBlobService(account_name=name, account_key=key)
|
||||
return service
|
||||
if issubclass(client_class, SubscriptionClient):
|
||||
return client_class(get_msi())
|
||||
else:
|
||||
return client_class(get_msi(), get_subscription())
|
||||
|
||||
|
||||
@cached
|
||||
@ -92,16 +70,6 @@ def get_insights_appid() -> str:
|
||||
return os.environ["APPINSIGHTS_APPID"]
|
||||
|
||||
|
||||
# @cached
|
||||
def get_fuzz_storage() -> str:
|
||||
return os.environ["ONEFUZZ_DATA_STORAGE"]
|
||||
|
||||
|
||||
# @cached
|
||||
def get_func_storage() -> str:
|
||||
return os.environ["ONEFUZZ_FUNC_STORAGE"]
|
||||
|
||||
|
||||
@cached
|
||||
def get_instance_name() -> str:
|
||||
return os.environ["ONEFUZZ_INSTANCE_NAME"]
|
||||
@ -114,9 +82,10 @@ def get_instance_url() -> str:
|
||||
|
||||
@cached
|
||||
def get_instance_id() -> UUID:
|
||||
from .containers import StorageType, get_blob
|
||||
from .containers import get_blob
|
||||
from .storage import StorageType
|
||||
|
||||
blob = get_blob("base-config", "instance_id", StorageType.config)
|
||||
blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
|
||||
if blob is None:
|
||||
raise Exception("missing instance_id")
|
||||
return UUID(blob.decode())
|
||||
|
@ -19,8 +19,7 @@ from azure.storage.queue import (
|
||||
from memoization import cached
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .containers import StorageType, get_account_id_by_type
|
||||
from .creds import get_storage_account_name_key
|
||||
from .storage import StorageType, get_primary_account, get_storage_account_name_key
|
||||
|
||||
QueueNameType = Union[str, UUID]
|
||||
|
||||
@ -29,7 +28,7 @@ DEFAULT_TTL = -1
|
||||
|
||||
@cached(ttl=60)
|
||||
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
|
||||
account_id = get_account_id_by_type(storage_type)
|
||||
account_id = get_primary_account(storage_type)
|
||||
logging.debug("getting blob container (account_id: %s)", account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
account_url = "https://%s.queue.core.windows.net" % name
|
||||
@ -50,7 +49,7 @@ def get_queue_sas(
|
||||
update: bool = False,
|
||||
process: bool = False,
|
||||
) -> str:
|
||||
account_id = get_account_id_by_type(storage_type)
|
||||
account_id = get_primary_account(storage_type)
|
||||
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
|
||||
name, key = get_storage_account_name_key(account_id)
|
||||
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
|
108
src/api-service/__app__/onefuzzlib/azure/storage.py
Normal file
108
src/api-service/__app__/onefuzzlib/azure/storage.py
Normal file
@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from enum import Enum
|
||||
from typing import List, Tuple
|
||||
|
||||
from azure.mgmt.storage import StorageManagementClient
|
||||
from memoization import cached
|
||||
from msrestazure.tools import parse_resource_id
|
||||
|
||||
from .creds import get_base_resource_group, mgmt_client_factory
|
||||
|
||||
|
||||
class StorageType(Enum):
    """Kinds of storage account distinguished by this module."""

    # Resolved via get_fuzz_storage() / corpus_accounts().
    corpus = "corpus"
    # Resolved via get_func_storage().
    config = "config"
|
||||
|
||||
|
||||
@cached
def get_fuzz_storage() -> str:
    """Return the value of the ``ONEFUZZ_DATA_STORAGE`` environment variable.

    Raises ``KeyError`` if the variable is unset. The result is memoized.
    """
    account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
    return account_id
|
||||
|
||||
|
||||
@cached
def get_func_storage() -> str:
    """Return the value of the ``ONEFUZZ_FUNC_STORAGE`` environment variable.

    Raises ``KeyError`` if the variable is unset. The result is memoized.
    """
    account_id = os.environ["ONEFUZZ_FUNC_STORAGE"]
    return account_id
|
||||
|
||||
|
||||
@cached
def get_primary_account(storage_type: StorageType) -> str:
    """Return the primary storage account for ``storage_type``.

    ``corpus`` maps to ``get_fuzz_storage()`` and ``config`` maps to
    ``get_func_storage()``.

    Raises:
        NotImplementedError: for any other storage type.
    """
    # see #322 for discussion about typing
    if storage_type == StorageType.config:
        return get_func_storage()
    if storage_type == StorageType.corpus:
        return get_fuzz_storage()
    raise NotImplementedError
|
||||
|
||||
|
||||
@cached
def get_accounts(storage_type: StorageType) -> List[str]:
    """Return every storage account usable for ``storage_type``.

    ``config`` always yields the single function storage account;
    ``corpus`` yields whatever ``corpus_accounts()`` discovers.

    Raises:
        NotImplementedError: for any other storage type.
    """
    if storage_type == StorageType.config:
        return [get_func_storage()]
    if storage_type == StorageType.corpus:
        return corpus_accounts()
    raise NotImplementedError
|
||||
|
||||
|
||||
@cached
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
    """Return ``(account name, first access key)`` for a storage account.

    ``account_id`` is parsed with ``parse_resource_id`` to get the resource
    group and account name, which are then used to list the account keys
    through a ``StorageManagementClient``. The result is memoized.
    """
    client = mgmt_client_factory(StorageManagementClient)
    parsed = parse_resource_id(account_id)
    group = parsed["resource_group"]
    account_name = parsed["name"]
    listing = client.storage_accounts.list_keys(group, account_name)
    # Only the first key in the listing is used.
    return account_name, listing.keys[0].value
|
||||
|
||||
|
||||
def choose_account(storage_type: StorageType) -> str:
    """Pick the storage account in which a new container should be created.

    Returns the only account when there is exactly one; otherwise a random
    account from all but the first entry of ``get_accounts``.

    Raises:
        Exception: if no accounts are available for ``storage_type``.
    """
    candidates = get_accounts(storage_type)
    if not candidates:
        raise Exception(f"no storage accounts for {storage_type}")

    primary, *secondaries = candidates
    if not secondaries:
        return primary

    # Use a random secondary storage account if any are available.  This
    # reduces IOP contention for the Storage Queues, which are only available
    # on primary accounts
    return random.choice(secondaries)
|
||||
|
||||
|
||||
@cached
def corpus_accounts() -> List[str]:
    """Discover the storage accounts usable for corpus containers.

    The list always starts with the account from ``get_fuzz_storage()``.
    It is followed by every account in the base resource group that has a
    blob endpoint and carries the tag ``storage_type`` with the value
    ``corpus``. The account from ``get_func_storage()`` is never included.

    The result is memoized by ``@cached`` (no TTL is given), so accounts
    added later are not picked up by an already-populated cache.
    """
    skip = get_func_storage()
    results = [get_fuzz_storage()]

    client = mgmt_client_factory(StorageManagementClient)
    group = get_base_resource_group()
    for account in client.storage_accounts.list_by_resource_group(group):
        # protection from someone adding the corpus tag to the config account
        if account.id == skip:
            continue

        # the primary account is already first in the list
        if account.id in results:
            continue

        if account.primary_endpoints.blob is None:
            continue

        # `tags` may be None for an untagged account; treat that as "no tags"
        # instead of failing on the membership test.
        tags = account.tags or {}
        if tags.get("storage_type") != "corpus":
            continue

        results.append(account.id)

    # Log the discovered accounts (previously the function object was logged).
    logging.info("corpus accounts: %s", results)
    return results
|
@ -10,7 +10,7 @@ from typing import Optional
|
||||
from azure.cosmosdb.table import TableService
|
||||
from memoization import cached
|
||||
|
||||
from .creds import get_storage_account_name_key
|
||||
from .storage import get_storage_account_name_key
|
||||
|
||||
|
||||
@cached(ttl=60)
|
||||
|
Reference in New Issue
Block a user