support multiple corpus accounts (#334)

Add support for sharding across multiple storage accounts for blob containers used for corpus management.

Things to note:

1. Additional storage accounts must be in the same resource group, support the "blob" endpoint, and have the tag `storage_type` with the value `corpus`.  A utility is provided (`src/utils/add-corpus-storage-accounts`), which adds storage accounts. 
2. If any secondary storage accounts exist, they are used by default for containers.
3. Storage account names are cached in memory the Azure Function instance forever.   Upon adding new storage accounts, the app needs to be restarted to pick up the new accounts.
This commit is contained in:
bmc-msft
2021-01-06 18:11:39 -05:00
committed by GitHub
parent f345bd239d
commit 3b26ffef65
29 changed files with 496 additions and 179 deletions

View File

@ -12,9 +12,9 @@ from onefuzztypes.models import Error
from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import get_instance_url from ..onefuzzlib.azure.creds import get_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_agent from ..onefuzzlib.endpoint_authorization import call_if_agent
from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool
from ..onefuzzlib.request import not_ok, ok, parse_uri from ..onefuzzlib.request import not_ok, ok, parse_uri

View File

@ -13,13 +13,13 @@ from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase
from ..onefuzzlib.azure.containers import ( from ..onefuzzlib.azure.containers import (
StorageType,
create_container, create_container,
delete_container, delete_container,
get_container_metadata, get_container_metadata,
get_container_sas_url, get_container_sas_url,
get_containers, get_containers,
) )
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, ok, parse_request from ..onefuzzlib.request import not_ok, ok, parse_request

View File

@ -8,11 +8,11 @@ from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, FileEntry from onefuzztypes.models import Error, FileEntry
from ..onefuzzlib.azure.containers import ( from ..onefuzzlib.azure.containers import (
StorageType,
blob_exists, blob_exists,
container_exists, container_exists,
get_file_sas_url, get_file_sas_url,
) )
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.endpoint_authorization import call_if_user from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.request import not_ok, parse_uri, redirect from ..onefuzzlib.request import not_ok, parse_uri, redirect

View File

@ -4,81 +4,134 @@
# Licensed under the MIT License. # Licensed under the MIT License.
import datetime import datetime
import logging
import os import os
import urllib.parse import urllib.parse
from enum import Enum from typing import Dict, Optional, Union, cast
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions from azure.storage.blob import BlobPermissions, BlockBlobService, ContainerPermissions
from memoization import cached from memoization import cached
from onefuzztypes.primitives import Container
from .creds import get_blob_service, get_func_storage, get_fuzz_storage from .storage import (
StorageType,
choose_account,
get_accounts,
get_storage_account_name_key,
)
class StorageType(Enum): @cached
corpus = "corpus" def get_blob_service(account_id: str) -> BlockBlobService:
config = "config" logging.debug("getting blob container (account_id: %s)", account_id)
account_name, account_key = get_storage_account_name_key(account_id)
service = BlockBlobService(account_name=account_name, account_key=account_key)
return service
def get_account_id_by_type(storage_type: StorageType) -> str: def get_service_by_container(
if storage_type == StorageType.corpus: container: Container, storage_type: StorageType
account_id = get_fuzz_storage() ) -> Optional[BlockBlobService]:
elif storage_type == StorageType.config: account = get_account_by_container(container, storage_type)
account_id = get_func_storage() if account is None:
else: return None
raise NotImplementedError service = get_blob_service(account)
return account_id return service
@cached(ttl=5) def container_exists_on_account(container: Container, account_id: str) -> bool:
def get_blob_service_by_type(storage_type: StorageType) -> Any:
account_id = get_account_id_by_type(storage_type)
return get_blob_service(account_id)
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
try: try:
get_blob_service_by_type(storage_type).get_container_properties(name) get_blob_service(account_id).get_container_properties(container)
return True return True
except AzureHttpError: except AzureHttpError:
return False return False
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]: def container_metadata(container: Container, account: str) -> Optional[Dict[str, str]]:
return {
x.name: x.metadata
for x in get_blob_service_by_type(storage_type).list_containers(
include_metadata=True
)
if not x.name.startswith("$")
}
def get_container_metadata(
name: str, storage_type: StorageType
) -> Optional[Dict[str, str]]:
try: try:
result = get_blob_service_by_type(storage_type).get_container_metadata(name) result = get_blob_service(account).get_container_metadata(container)
return cast(Dict[str, str], result) return cast(Dict[str, str], result)
except AzureHttpError: except AzureHttpError:
pass pass
return None return None
def create_container( def get_account_by_container(
name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]] container: Container, storage_type: StorageType
) -> Optional[str]: ) -> Optional[str]:
try: accounts = get_accounts(storage_type)
get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
except AzureHttpError: # check secondary accounts first by searching in reverse.
# azure storage already logs errors #
# By implementation, the primary account is specified first, followed by
# any secondary accounts.
#
# Secondary accounts, if they exist, are preferred for containers and have
# increased IOP rates, this should be a slight optimization
for account in reversed(accounts):
if container_exists_on_account(container, account):
return account
return None
def container_exists(container: Container, storage_type: StorageType) -> bool:
return get_account_by_container(container, storage_type) is not None
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
containers: Dict[str, Dict[str, str]] = {}
for account_id in get_accounts(storage_type):
containers.update(
{
x.name: x.metadata
for x in get_blob_service(account_id).list_containers(
include_metadata=True
)
}
)
return containers
def get_container_metadata(
container: Container, storage_type: StorageType
) -> Optional[Dict[str, str]]:
account = get_account_by_container(container, storage_type)
if account is None:
return None return None
return get_container_sas_url( return container_metadata(container, account)
name,
storage_type,
def create_container(
container: Container,
storage_type: StorageType,
metadata: Optional[Dict[str, str]],
) -> Optional[str]:
service = get_service_by_container(container, storage_type)
if service is None:
account = choose_account(storage_type)
service = get_blob_service(account)
try:
service.create_container(container, metadata=metadata)
except AzureHttpError as err:
logging.error(
(
"unable to create container. account: %s "
"container: %s metadata: %s - %s"
),
account,
container,
metadata,
err,
)
return None
return get_container_sas_url_service(
container,
service,
read=True, read=True,
add=True, add=True,
create=True, create=True,
@ -88,17 +141,19 @@ def create_container(
) )
def delete_container(name: str, storage_type: StorageType) -> bool: def delete_container(container: Container, storage_type: StorageType) -> bool:
try: accounts = get_accounts(storage_type)
return bool(get_blob_service_by_type(storage_type).delete_container(name)) for account in accounts:
except AzureHttpError: service = get_blob_service(account)
# azure storage already logs errors if bool(service.delete_container(container)):
return False return True
return False
def get_container_sas_url( def get_container_sas_url_service(
container: str, container: Container,
storage_type: StorageType, service: BlockBlobService,
*, *,
read: bool = False, read: bool = False,
add: bool = False, add: bool = False,
@ -107,7 +162,6 @@ def get_container_sas_url(
delete: bool = False, delete: bool = False,
list: bool = False, list: bool = False,
) -> str: ) -> str:
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
permission = ContainerPermissions(read, add, create, write, delete, list) permission = ContainerPermissions(read, add, create, write, delete, list)
@ -120,8 +174,35 @@ def get_container_sas_url(
return str(url) return str(url)
def get_container_sas_url(
container: Container,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
write: bool = False,
delete: bool = False,
list: bool = False,
) -> str:
service = get_service_by_container(container, storage_type)
if not service:
raise Exception("unable to create container sas for missing container")
return get_container_sas_url_service(
container,
service,
read=read,
add=add,
create=create,
write=write,
delete=delete,
list=list,
)
def get_file_sas_url( def get_file_sas_url(
container: str, container: Container,
name: str, name: str,
storage_type: StorageType, storage_type: StorageType,
*, *,
@ -135,7 +216,10 @@ def get_file_sas_url(
hours: int = 0, hours: int = 0,
minutes: int = 0, minutes: int = 0,
) -> str: ) -> str:
service = get_blob_service_by_type(storage_type) service = get_service_by_container(container, storage_type)
if not service:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
expiry = datetime.datetime.utcnow() + datetime.timedelta( expiry = datetime.datetime.utcnow() + datetime.timedelta(
days=days, hours=hours, minutes=minutes days=days, hours=hours, minutes=minutes
) )
@ -150,18 +234,28 @@ def get_file_sas_url(
def save_blob( def save_blob(
container: str, name: str, data: Union[str, bytes], storage_type: StorageType container: Container,
name: str,
data: Union[str, bytes],
storage_type: StorageType,
) -> None: ) -> None:
service = get_blob_service_by_type(storage_type) service = get_service_by_container(container, storage_type)
service.create_container(container) if not service:
raise Exception("unable to find container: %s - %s" % (container, storage_type))
if isinstance(data, str): if isinstance(data, str):
service.create_blob_from_text(container, name, data) service.create_blob_from_text(container, name, data)
elif isinstance(data, bytes): elif isinstance(data, bytes):
service.create_blob_from_bytes(container, name, data) service.create_blob_from_bytes(container, name, data)
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]: def get_blob(
service = get_blob_service_by_type(storage_type) container: Container, name: str, storage_type: StorageType
) -> Optional[bytes]:
service = get_service_by_container(container, storage_type)
if not service:
return None
try: try:
blob = service.get_blob_to_bytes(container, name).content blob = service.get_blob_to_bytes(container, name).content
return cast(bytes, blob) return cast(bytes, blob)
@ -169,8 +263,11 @@ def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[b
return None return None
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool: def blob_exists(container: Container, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type) service = get_service_by_container(container, storage_type)
if not service:
return False
try: try:
service.get_blob_properties(container, name) service.get_blob_properties(container, name)
return True return True
@ -178,8 +275,11 @@ def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
return False return False
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool: def delete_blob(container: Container, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type) service = get_service_by_container(container, storage_type)
if not service:
return False
try: try:
service.delete_blob(container, name) service.delete_blob(container, name)
return True return True
@ -187,7 +287,7 @@ def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
return False return False
def auth_download_url(container: str, filename: str) -> str: def auth_download_url(container: Container, filename: str) -> str:
instance = os.environ["ONEFUZZ_INSTANCE"] instance = os.environ["ONEFUZZ_INSTANCE"]
return "%s/api/download?%s" % ( return "%s/api/download?%s" % (
instance, instance,

View File

@ -3,9 +3,8 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License. # Licensed under the MIT License.
import logging
import os import os
from typing import Any, List, Optional, Tuple from typing import Any, List
from uuid import UUID from uuid import UUID
from azure.cli.core import CLIError from azure.cli.core import CLIError
@ -13,12 +12,11 @@ from azure.common.client_factory import get_client_from_cli_profile
from azure.graphrbac import GraphRbacManagementClient from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import CheckGroupMembershipParameters from azure.graphrbac.models import CheckGroupMembershipParameters
from azure.mgmt.resource import ResourceManagementClient from azure.mgmt.resource import ResourceManagementClient
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.subscription import SubscriptionClient from azure.mgmt.subscription import SubscriptionClient
from azure.storage.blob import BlockBlobService
from memoization import cached from memoization import cached
from msrestazure.azure_active_directory import MSIAuthentication from msrestazure.azure_active_directory import MSIAuthentication
from msrestazure.tools import parse_resource_id from msrestazure.tools import parse_resource_id
from onefuzztypes.primitives import Container
from .monkeypatch import allow_more_workers, reduce_logging from .monkeypatch import allow_more_workers, reduce_logging
@ -35,34 +33,14 @@ def mgmt_client_factory(client_class: Any) -> Any:
try: try:
return get_client_from_cli_profile(client_class) return get_client_from_cli_profile(client_class)
except CLIError: except CLIError:
if issubclass(client_class, SubscriptionClient): pass
return client_class(get_msi()) except OSError:
else: pass
return client_class(get_msi(), get_subscription())
if issubclass(client_class, SubscriptionClient):
@cached return client_class(get_msi())
def get_storage_account_name_key(account_id: Optional[str] = None) -> Tuple[str, str]: else:
db_client = mgmt_client_factory(StorageManagementClient) return client_class(get_msi(), get_subscription())
if account_id is None:
account_id = os.environ["ONEFUZZ_DATA_STORAGE"]
resource = parse_resource_id(account_id)
key = (
db_client.storage_accounts.list_keys(
resource["resource_group"], resource["name"]
)
.keys[0]
.value
)
return resource["name"], key
@cached
def get_blob_service(account_id: Optional[str] = None) -> BlockBlobService:
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
service = BlockBlobService(account_name=name, account_key=key)
return service
@cached @cached
@ -92,16 +70,6 @@ def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"] return os.environ["APPINSIGHTS_APPID"]
# @cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
# @cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached @cached
def get_instance_name() -> str: def get_instance_name() -> str:
return os.environ["ONEFUZZ_INSTANCE_NAME"] return os.environ["ONEFUZZ_INSTANCE_NAME"]
@ -114,9 +82,10 @@ def get_instance_url() -> str:
@cached @cached
def get_instance_id() -> UUID: def get_instance_id() -> UUID:
from .containers import StorageType, get_blob from .containers import get_blob
from .storage import StorageType
blob = get_blob("base-config", "instance_id", StorageType.config) blob = get_blob(Container("base-config"), "instance_id", StorageType.config)
if blob is None: if blob is None:
raise Exception("missing instance_id") raise Exception("missing instance_id")
return UUID(blob.decode()) return UUID(blob.decode())

View File

@ -19,8 +19,7 @@ from azure.storage.queue import (
from memoization import cached from memoization import cached
from pydantic import BaseModel from pydantic import BaseModel
from .containers import StorageType, get_account_id_by_type from .storage import StorageType, get_primary_account, get_storage_account_name_key
from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID] QueueNameType = Union[str, UUID]
@ -29,7 +28,7 @@ DEFAULT_TTL = -1
@cached(ttl=60) @cached(ttl=60)
def get_queue_client(storage_type: StorageType) -> QueueServiceClient: def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_account_id_by_type(storage_type) account_id = get_primary_account(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id) logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id) name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name account_url = "https://%s.queue.core.windows.net" % name
@ -50,7 +49,7 @@ def get_queue_sas(
update: bool = False, update: bool = False,
process: bool = False, process: bool = False,
) -> str: ) -> str:
account_id = get_account_id_by_type(storage_type) account_id = get_primary_account(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id) logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id) name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)

View File

@ -0,0 +1,108 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import os
import random
from enum import Enum
from typing import List, Tuple
from azure.mgmt.storage import StorageManagementClient
from memoization import cached
from msrestazure.tools import parse_resource_id
from .creds import get_base_resource_group, mgmt_client_factory
class StorageType(Enum):
corpus = "corpus"
config = "config"
@cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@cached
def get_primary_account(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
# see #322 for discussion about typing
return get_fuzz_storage()
elif storage_type == StorageType.config:
# see #322 for discussion about typing
return get_func_storage()
raise NotImplementedError
@cached
def get_accounts(storage_type: StorageType) -> List[str]:
if storage_type == StorageType.corpus:
return corpus_accounts()
elif storage_type == StorageType.config:
return [get_func_storage()]
else:
raise NotImplementedError
@cached
def get_storage_account_name_key(account_id: str) -> Tuple[str, str]:
client = mgmt_client_factory(StorageManagementClient)
resource = parse_resource_id(account_id)
key = (
client.storage_accounts.list_keys(resource["resource_group"], resource["name"])
.keys[0]
.value
)
return resource["name"], key
def choose_account(storage_type: StorageType) -> str:
accounts = get_accounts(storage_type)
if not accounts:
raise Exception(f"no storage accounts for {storage_type}")
if len(accounts) == 1:
return accounts[0]
# Use a random secondary storage account if any are available. This
# reduces IOP contention for the Storage Queues, which are only available
# on primary accounts
return random.choice(accounts[1:])
@cached
def corpus_accounts() -> List[str]:
skip = get_func_storage()
results = [get_fuzz_storage()]
client = mgmt_client_factory(StorageManagementClient)
group = get_base_resource_group()
for account in client.storage_accounts.list_by_resource_group(group):
# protection from someone adding the corpus tag to the config account
if account.id == skip:
continue
if account.id in results:
continue
if account.primary_endpoints.blob is None:
continue
if (
"storage_type" not in account.tags
or account.tags["storage_type"] != "corpus"
):
continue
results.append(account.id)
logging.info("corpus accounts: %s", corpus_accounts)
return results

View File

@ -10,7 +10,7 @@ from typing import Optional
from azure.cosmosdb.table import TableService from azure.cosmosdb.table import TableService
from memoization import cached from memoization import cached
from .creds import get_storage_account_name_key from .storage import get_storage_account_name_key
@cached(ttl=60) @cached(ttl=60)

View File

@ -9,17 +9,13 @@ from uuid import UUID
from onefuzztypes.enums import OS, AgentMode from onefuzztypes.enums import OS, AgentMode
from onefuzztypes.models import AgentConfig, Pool, ReproConfig, Scaleset from onefuzztypes.models import AgentConfig, Pool, ReproConfig, Scaleset
from onefuzztypes.primitives import Extension, Region from onefuzztypes.primitives import Container, Extension, Region
from .azure.containers import ( from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob
StorageType,
get_container_sas_url,
get_file_sas_url,
save_blob,
)
from .azure.creds import get_instance_id, get_instance_url from .azure.creds import get_instance_id, get_instance_url
from .azure.monitor import get_monitor_settings from .azure.monitor import get_monitor_settings
from .azure.queue import get_queue_sas from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .reports import get_report from .reports import get_report
@ -95,8 +91,12 @@ def build_scaleset_script(pool: Pool, scaleset: Scaleset) -> str:
ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys" ssh_path = "$env:ProgramData/ssh/administrators_authorized_keys"
commands += [f'Set-Content -Path {ssh_path} -Value "{ssh_key}"'] commands += [f'Set-Content -Path {ssh_path} -Value "{ssh_key}"']
save_blob("vm-scripts", filename, sep.join(commands) + sep, StorageType.config) save_blob(
return get_file_sas_url("vm-scripts", filename, StorageType.config, read=True) Container("vm-scripts"), filename, sep.join(commands) + sep, StorageType.config
)
return get_file_sas_url(
Container("vm-scripts"), filename, StorageType.config, read=True
)
def build_pool_config(pool: Pool) -> str: def build_pool_config(pool: Pool) -> str:
@ -116,14 +116,14 @@ def build_pool_config(pool: Pool) -> str:
filename = f"{pool.name}/config.json" filename = f"{pool.name}/config.json"
save_blob( save_blob(
"vm-scripts", Container("vm-scripts"),
filename, filename,
config.json(), config.json(),
StorageType.config, StorageType.config,
) )
return get_file_sas_url( return get_file_sas_url(
"vm-scripts", Container("vm-scripts"),
filename, filename,
StorageType.config, StorageType.config,
read=True, read=True,
@ -135,24 +135,28 @@ def update_managed_scripts() -> None:
"azcopy sync '%s' instance-specific-setup" "azcopy sync '%s' instance-specific-setup"
% ( % (
get_container_sas_url( get_container_sas_url(
"instance-specific-setup", Container("instance-specific-setup"),
StorageType.config, StorageType.config,
read=True, read=True,
list=True, list=True,
) )
), ),
"azcopy sync '%s' tools" "azcopy sync '%s' tools"
% (get_container_sas_url("tools", StorageType.config, read=True, list=True)), % (
get_container_sas_url(
Container("tools"), StorageType.config, read=True, list=True
)
),
] ]
save_blob( save_blob(
"vm-scripts", Container("vm-scripts"),
"managed.ps1", "managed.ps1",
"\r\n".join(commands) + "\r\n", "\r\n".join(commands) + "\r\n",
StorageType.config, StorageType.config,
) )
save_blob( save_blob(
"vm-scripts", Container("vm-scripts"),
"managed.sh", "managed.sh",
"\n".join(commands) + "\n", "\n".join(commands) + "\n",
StorageType.config, StorageType.config,
@ -170,25 +174,25 @@ def agent_config(
if vm_os == OS.windows: if vm_os == OS.windows:
urls += [ urls += [
get_file_sas_url( get_file_sas_url(
"vm-scripts", Container("vm-scripts"),
"managed.ps1", "managed.ps1",
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"win64/azcopy.exe", "win64/azcopy.exe",
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"win64/setup.ps1", "win64/setup.ps1",
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"win64/onefuzz.ps1", "win64/onefuzz.ps1",
StorageType.config, StorageType.config,
read=True, read=True,
@ -212,19 +216,19 @@ def agent_config(
elif vm_os == OS.linux: elif vm_os == OS.linux:
urls += [ urls += [
get_file_sas_url( get_file_sas_url(
"vm-scripts", Container("vm-scripts"),
"managed.sh", "managed.sh",
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"linux/azcopy", "linux/azcopy",
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"linux/setup.sh", "linux/setup.sh",
StorageType.config, StorageType.config,
read=True, read=True,
@ -260,7 +264,7 @@ def repro_extensions(
repro_os: OS, repro_os: OS,
repro_id: UUID, repro_id: UUID,
repro_config: ReproConfig, repro_config: ReproConfig,
setup_container: Optional[str], setup_container: Optional[Container],
) -> List[Extension]: ) -> List[Extension]:
# TODO - what about contents of repro.ps1 / repro.sh? # TODO - what about contents of repro.ps1 / repro.sh?
report = get_report(repro_config.container, repro_config.path) report = get_report(repro_config.container, repro_config.path)
@ -302,7 +306,7 @@ def repro_extensions(
script_name = "task-setup.sh" script_name = "task-setup.sh"
save_blob( save_blob(
"task-configs", Container("task-configs"),
"%s/%s" % (repro_id, script_name), "%s/%s" % (repro_id, script_name),
task_script, task_script,
StorageType.config, StorageType.config,
@ -311,13 +315,13 @@ def repro_extensions(
for repro_file in repro_files: for repro_file in repro_files:
urls += [ urls += [
get_file_sas_url( get_file_sas_url(
"repro-scripts", Container("repro-scripts"),
repro_file, repro_file,
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"task-configs", Container("task-configs"),
"%s/%s" % (repro_id, script_name), "%s/%s" % (repro_id, script_name),
StorageType.config, StorageType.config,
read=True, read=True,
@ -333,13 +337,13 @@ def repro_extensions(
def proxy_manager_extensions(region: Region) -> List[Extension]: def proxy_manager_extensions(region: Region) -> List[Extension]:
urls = [ urls = [
get_file_sas_url( get_file_sas_url(
"proxy-configs", Container("proxy-configs"),
"%s/config.json" % region, "%s/config.json" % region,
StorageType.config, StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", Container("tools"),
"linux/onefuzz-proxy-manager", "linux/onefuzz-proxy-manager",
StorageType.config, StorageType.config,
read=True, read=True,

View File

@ -25,6 +25,7 @@ from azure.devops.v6_0.work_item_tracking.work_item_tracking_client import (
) )
from memoization import cached from memoization import cached
from onefuzztypes.models import ADOTemplate, Report from onefuzztypes.models import ADOTemplate, Report
from onefuzztypes.primitives import Container
from .common import Render, fail_task from .common import Render, fail_task
@ -49,7 +50,7 @@ def get_valid_fields(
class ADO: class ADO:
def __init__( def __init__(
self, container: str, filename: str, config: ADOTemplate, report: Report self, container: Container, filename: str, config: ADOTemplate, report: Report
): ):
self.config = config self.config = config
self.renderer = Render(container, filename, report) self.renderer = Render(container, filename, report)
@ -200,7 +201,7 @@ class ADO:
def notify_ado( def notify_ado(
config: ADOTemplate, container: str, filename: str, report: Report config: ADOTemplate, container: Container, filename: str, report: Report
) -> None: ) -> None:
logging.info( logging.info(
"notify ado: job_id:%s task_id:%s container:%s filename:%s", "notify ado: job_id:%s task_id:%s container:%s filename:%s",

View File

@ -9,6 +9,7 @@ from typing import Optional
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
from onefuzztypes.enums import ErrorCode from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, Report from onefuzztypes.models import Error, Report
from onefuzztypes.primitives import Container
from ..azure.containers import auth_download_url from ..azure.containers import auth_download_url
from ..azure.creds import get_instance_url from ..azure.creds import get_instance_url
@ -33,7 +34,7 @@ def fail_task(report: Report, error: Exception) -> None:
class Render: class Render:
def __init__(self, container: str, filename: str, report: Report): def __init__(self, container: Container, filename: str, report: Report):
self.report = report self.report = report
self.container = container self.container = container
self.filename = filename self.filename = filename

View File

@ -11,13 +11,18 @@ from github3.exceptions import GitHubException
from github3.issues import Issue from github3.issues import Issue
from onefuzztypes.enums import GithubIssueSearchMatch from onefuzztypes.enums import GithubIssueSearchMatch
from onefuzztypes.models import GithubIssueTemplate, Report from onefuzztypes.models import GithubIssueTemplate, Report
from onefuzztypes.primitives import Container
from .common import Render, fail_task from .common import Render, fail_task
class GithubIssue: class GithubIssue:
def __init__( def __init__(
self, config: GithubIssueTemplate, container: str, filename: str, report: Report self,
config: GithubIssueTemplate,
container: Container,
filename: str,
report: Report,
): ):
self.config = config self.config = config
self.report = report self.report = report
@ -95,7 +100,10 @@ class GithubIssue:
def github_issue( def github_issue(
config: GithubIssueTemplate, container: str, filename: str, report: Optional[Report] config: GithubIssueTemplate,
container: Container,
filename: str,
report: Optional[Report],
) -> None: ) -> None:
if report is None: if report is None:
return return

View File

@ -21,12 +21,12 @@ from onefuzztypes.models import (
from onefuzztypes.primitives import Container, Event from onefuzztypes.primitives import Container, Event
from ..azure.containers import ( from ..azure.containers import (
StorageType,
container_exists, container_exists,
get_container_metadata, get_container_metadata,
get_file_sas_url, get_file_sas_url,
) )
from ..azure.queue import send_message from ..azure.queue import send_message
from ..azure.storage import StorageType
from ..dashboard import add_event from ..dashboard import add_event
from ..orm import ORMMixin from ..orm import ORMMixin
from ..reports import get_report from ..reports import get_report

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
import requests import requests
from onefuzztypes.models import Report, TeamsTemplate from onefuzztypes.models import Report, TeamsTemplate
from onefuzztypes.primitives import Container
from ..azure.containers import auth_download_url from ..azure.containers import auth_download_url
from ..tasks.config import get_setup_container from ..tasks.config import get_setup_container
@ -51,7 +52,7 @@ def send_teams_webhook(
def notify_teams( def notify_teams(
config: TeamsTemplate, container: str, filename: str, report: Optional[Report] config: TeamsTemplate, container: Container, filename: str, report: Optional[Report]
) -> None: ) -> None:
text = None text = None
facts: List[Dict[str, str]] = [] facts: List[Dict[str, str]] = []

View File

@ -36,7 +36,6 @@ from pydantic import BaseModel, Field
from .__version__ import __version__ from .__version__ import __version__
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.containers import StorageType
from .azure.image import get_os from .azure.image import get_os
from .azure.network import Network from .azure.network import Network
from .azure.queue import ( from .azure.queue import (
@ -47,6 +46,7 @@ from .azure.queue import (
queue_object, queue_object,
remove_first_message, remove_first_message,
) )
from .azure.storage import StorageType
from .azure.vmss import ( from .azure.vmss import (
UnableToUpdate, UnableToUpdate,
create_vmss, create_vmss,

View File

@ -16,14 +16,15 @@ from onefuzztypes.models import (
ProxyConfig, ProxyConfig,
ProxyHeartbeat, ProxyHeartbeat,
) )
from onefuzztypes.primitives import Region from onefuzztypes.primitives import Container, Region
from pydantic import Field from pydantic import Field
from .__version__ import __version__ from .__version__ import __version__
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.containers import StorageType, get_file_sas_url, save_blob from .azure.containers import get_file_sas_url, save_blob
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas from .azure.queue import get_queue_sas
from .azure.storage import StorageType
from .azure.vm import VM from .azure.vm import VM
from .extension import proxy_manager_extensions from .extension import proxy_manager_extensions
from .orm import MappingIntStrAny, ORMMixin, QueryFilter from .orm import MappingIntStrAny, ORMMixin, QueryFilter
@ -188,7 +189,7 @@ class Proxy(ORMMixin):
forwards = self.get_forwards() forwards = self.get_forwards()
proxy_config = ProxyConfig( proxy_config = ProxyConfig(
url=get_file_sas_url( url=get_file_sas_url(
"proxy-configs", Container("proxy-configs"),
"%s/config.json" % self.region, "%s/config.json" % self.region,
StorageType.config, StorageType.config,
read=True, read=True,
@ -203,7 +204,7 @@ class Proxy(ORMMixin):
) )
save_blob( save_blob(
"proxy-configs", Container("proxy-configs"),
"%s/config.json" % self.region, "%s/config.json" % self.region,
proxy_config.json(), proxy_config.json(),
StorageType.config, StorageType.config,

View File

@ -8,9 +8,11 @@ import logging
from typing import Optional, Union from typing import Optional, Union
from onefuzztypes.models import Report from onefuzztypes.models import Report
from onefuzztypes.primitives import Container
from pydantic import ValidationError from pydantic import ValidationError
from .azure.containers import StorageType, get_blob from .azure.containers import get_blob
from .azure.storage import StorageType
def parse_report( def parse_report(
@ -44,7 +46,7 @@ def parse_report(
return entry return entry
def get_report(container: str, filename: str) -> Optional[Report]: def get_report(container: Container, filename: str) -> Optional[Report]:
metadata = "/".join([container, filename]) metadata = "/".join([container, filename])
if not filename.endswith(".json"): if not filename.endswith(".json"):
logging.error("get_report invalid extension: %s", metadata) logging.error("get_report invalid extension: %s", metadata)

View File

@ -12,11 +12,13 @@ from onefuzztypes.enums import OS, ContainerType, ErrorCode, VmState
from onefuzztypes.models import Error from onefuzztypes.models import Error
from onefuzztypes.models import Repro as BASE_REPRO from onefuzztypes.models import Repro as BASE_REPRO
from onefuzztypes.models import ReproConfig, TaskVm, UserInfo from onefuzztypes.models import ReproConfig, TaskVm, UserInfo
from onefuzztypes.primitives import Container
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.containers import StorageType, save_blob from .azure.containers import save_blob
from .azure.creds import get_base_region from .azure.creds import get_base_region
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.storage import StorageType
from .azure.vm import VM from .azure.vm import VM
from .extension import repro_extensions from .extension import repro_extensions
from .orm import ORMMixin, QueryFilter from .orm import ORMMixin, QueryFilter
@ -98,7 +100,7 @@ class Repro(BASE_REPRO, ORMMixin):
) )
return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors)) return self.set_error(Error(code=ErrorCode.VM_CREATE_FAILED, errors=errors))
def get_setup_container(self) -> Optional[str]: def get_setup_container(self) -> Optional[Container]:
task = Task.get_by_task_id(self.task_id) task = Task.get_by_task_id(self.task_id)
if isinstance(task, Task): if isinstance(task, Task):
for container in task.config.containers: for container in task.config.containers:
@ -202,7 +204,7 @@ class Repro(BASE_REPRO, ORMMixin):
for filename in files: for filename in files:
save_blob( save_blob(
"repro-scripts", Container("repro-scripts"),
"%s/%s" % (self.vm_id, filename), "%s/%s" % (self.vm_id, filename),
files[filename], files[filename],
StorageType.config, StorageType.config,

View File

@ -10,15 +10,12 @@ from uuid import UUID
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig
from onefuzztypes.primitives import Container
from ..azure.containers import ( from ..azure.containers import blob_exists, container_exists, get_container_sas_url
StorageType,
blob_exists,
container_exists,
get_container_sas_url,
)
from ..azure.creds import get_instance_id, get_instance_url from ..azure.creds import get_instance_id, get_instance_url
from ..azure.queue import get_queue_sas from ..azure.queue import get_queue_sas
from ..azure.storage import StorageType
from .defs import TASK_DEFINITIONS from .defs import TASK_DEFINITIONS
LOGGER = logging.getLogger("onefuzz") LOGGER = logging.getLogger("onefuzz")
@ -334,7 +331,7 @@ def build_task_config(
return config return config
def get_setup_container(config: TaskConfig) -> str: def get_setup_container(config: TaskConfig) -> Container:
for container in config.containers: for container in config.containers:
if container.type == ContainerType.setup: if container.type == ContainerType.setup:
return container.name return container.name

View File

@ -18,9 +18,9 @@ from onefuzztypes.webhooks import (
WebhookEventTaskStopped, WebhookEventTaskStopped,
) )
from ..azure.containers import StorageType
from ..azure.image import get_os from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue from ..azure.queue import create_queue, delete_queue
from ..azure.storage import StorageType
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
from ..pools import Node, Pool, Scaleset from ..pools import Node, Pool, Scaleset
from ..proxy_forward import ProxyForward from ..proxy_forward import ProxyForward

View File

@ -11,7 +11,8 @@ from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit from onefuzztypes.models import WorkSet, WorkUnit
from pydantic import BaseModel from pydantic import BaseModel
from ..azure.containers import StorageType, blob_exists, get_container_sas_url from ..azure.containers import blob_exists, get_container_sas_url
from ..azure.storage import StorageType
from ..pools import Pool from ..pools import Pool
from .config import build_task_config, get_setup_container from .config import build_task_config, get_setup_container
from .main import Task from .main import Task

View File

@ -10,8 +10,8 @@ from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType from onefuzztypes.enums import UpdateType
from pydantic import BaseModel from pydantic import BaseModel
from .azure.containers import StorageType
from .azure.queue import queue_object from .azure.queue import queue_object
from .azure.storage import StorageType
# This class isn't intended to be shared outside of the service # This class isn't intended to be shared outside of the service

View File

@ -27,8 +27,8 @@ from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel from pydantic import BaseModel
from .__version__ import __version__ from .__version__ import __version__
from .azure.containers import StorageType
from .azure.queue import queue_object from .azure.queue import queue_object
from .azure.storage import StorageType
from .orm import ORMMixin from .orm import ORMMixin
MAX_TRIES = 5 MAX_TRIES = 5

View File

@ -12,7 +12,6 @@ from onefuzztypes.models import AgentConfig, Error
from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop
from onefuzztypes.responses import BoolResult from onefuzztypes.responses import BoolResult
from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import ( from ..onefuzzlib.azure.creds import (
get_base_region, get_base_region,
get_instance_id, get_instance_id,
@ -20,6 +19,7 @@ from ..onefuzzlib.azure.creds import (
get_regions, get_regions,
) )
from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.azure.storage import StorageType
from ..onefuzzlib.azure.vmss import list_available_skus from ..onefuzzlib.azure.vmss import list_available_skus
from ..onefuzzlib.endpoint_authorization import call_if_user from ..onefuzzlib.endpoint_authorization import call_if_user
from ..onefuzzlib.pools import Pool from ..onefuzzlib.pools import Pool

View File

@ -9,7 +9,7 @@ from typing import Dict
import azure.functions as func import azure.functions as func
from ..onefuzzlib.azure.creds import get_fuzz_storage from ..onefuzzlib.azure.storage import corpus_accounts
from ..onefuzzlib.dashboard import get_event from ..onefuzzlib.dashboard import get_event
from ..onefuzzlib.notifications.main import new_files from ..onefuzzlib.notifications.main import new_files
@ -25,7 +25,7 @@ def file_added(event: Dict) -> None:
def main(msg: func.QueueMessage, dashboard: func.Out[str]) -> None: def main(msg: func.QueueMessage, dashboard: func.Out[str]) -> None:
event = json.loads(msg.get_body()) event = json.loads(msg.get_body())
if event["topic"] != get_fuzz_storage(): if event["topic"] in corpus_accounts():
return return
if event["eventType"] != "Microsoft.Storage.BlobCreated": if event["eventType"] != "Microsoft.Storage.BlobCreated":

View File

@ -580,6 +580,14 @@
"[resourceId('Microsoft.Storage/storageAccounts', variables('storageAccountName'))]" "[resourceId('Microsoft.Storage/storageAccounts', variables('storageAccountName'))]"
] ]
}, },
{
"type": "Microsoft.Storage/storageAccounts/blobServices/containers",
"apiVersion": "2018-03-01-preview",
"name": "[concat(variables('storageAccountNameFunc'), '/default/', 'vm-scripts')]",
"dependsOn": [
"[resourceId('Microsoft.Storage/storageAccounts', variables('storageAccountNameFunc'))]"
]
},
{ {
"type": "Microsoft.Storage/storageAccounts/blobServices/containers", "type": "Microsoft.Storage/storageAccounts/blobServices/containers",
"apiVersion": "2018-03-01-preview", "apiVersion": "2018-03-01-preview",

View File

@ -0,0 +1,106 @@
#!/usr/bin/env python
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import json
import uuid
from azure.common.client_factory import get_client_from_cli_profile
from azure.mgmt.eventgrid import EventGridManagementClient
from azure.mgmt.eventgrid.models import EventSubscription
from azure.mgmt.storage import StorageManagementClient
from azure.mgmt.storage.models import (
AccessTier,
Kind,
Sku,
SkuName,
StorageAccountCreateParameters,
)
# This was generated randomly and should be preserved moving forwards
STORAGE_GUID_NAMESPACE = uuid.UUID("f7eb528c-d849-4b81-9046-e7036f6203df")
def get_base_event(
client: EventGridManagementClient, resource_group: str, location: str
) -> EventSubscription:
for entry in client.event_subscriptions.list_regional_by_resource_group(
resource_group, location
):
if (
entry.name == "onefuzz1"
and entry.type == "Microsoft.EventGrid/eventSubscriptions"
and entry.event_delivery_schema == "EventGridSchema"
and entry.destination.endpoint_type == "StorageQueue"
and entry.destination.queue_name == "file-changes"
):
return entry
raise Exception("unable to find base eventgrid subscription")
def add_event_grid(src_account_id: str, resource_group: str, location: str) -> None:
client = get_client_from_cli_profile(EventGridManagementClient)
base = get_base_event(client, resource_group, location)
event_subscription_info = EventSubscription(
destination=base.destination,
filter=base.filter,
retry_policy=base.retry_policy,
)
topic_id = uuid.uuid5(STORAGE_GUID_NAMESPACE, src_account_id).hex
result = client.event_subscriptions.create_or_update(
src_account_id, "corpus" + topic_id, event_subscription_info
).result()
if result.provisioning_state != "Succeeded":
raise Exception(
"eventgrid subscription failed: %s"
% json.dumps(result.as_dict(), indent=4, sort_keys=True),
)
def create_storage(resource_group: str, account_name: str, location: str) -> str:
params = StorageAccountCreateParameters(
sku=Sku(name=SkuName.premium_lrs),
kind=Kind.block_blob_storage,
location=location,
tags={"storage_type": "corpus"},
access_tier=AccessTier.hot,
allow_blob_public_access=False,
minimum_tls_version="TLS1_2",
)
client = get_client_from_cli_profile(StorageManagementClient)
account = client.storage_accounts.create(
resource_group, account_name, params
).result()
if account.provisioning_state != "Succeeded":
raise Exception(
"storage account creation failed: %s",
json.dumps(account.as_dict(), indent=4, sort_keys=True),
)
return account.id
def create(resource_group: str, account_name: str, location: str) -> None:
new_account_id = create_storage(resource_group, account_name, location)
add_event_grid(new_account_id, resource_group, location)
def main():
formatter = argparse.ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(formatter_class=formatter)
parser.add_argument("resource_group")
parser.add_argument("account_name")
parser.add_argument("location")
args = parser.parse_args()
create(args.resource_group, args.account_name, args.location)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,6 @@
flake8
mypy
pytest
isort
vulture
black

View File

@ -0,0 +1,3 @@
azure-mgmt-storage~=11.2.0
azure-cli-core==2.13.0
azure-mgmt-eventgrid==3.0.0rc7