Use Storage Account types, rather than account_id (#320)

We need to move to supporting data sharding.

One of the steps towards that is stop passing around `account_id`, rather we need to specify the type of storage we need.
This commit is contained in:
bmc-msft 2020-11-18 09:06:14 -05:00 committed by GitHub
parent 52eca33237
commit e47e89609a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 205 additions and 139 deletions

View File

@ -13,7 +13,8 @@ from onefuzztypes.requests import AgentRegistrationGet, AgentRegistrationPost
from onefuzztypes.responses import AgentRegistration from onefuzztypes.responses import AgentRegistration
from ..onefuzzlib.agent_authorization import call_if_agent from ..onefuzzlib.agent_authorization import call_if_agent
from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import get_instance_url
from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.azure.queue import get_queue_sas
from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool
from ..onefuzzlib.request import not_ok, ok, parse_uri from ..onefuzzlib.request import not_ok, ok, parse_uri
@ -25,7 +26,7 @@ def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpRespo
commands_url = "%s/api/agents/commands" % base_address commands_url = "%s/api/agents/commands" % base_address
work_queue = get_queue_sas( work_queue = get_queue_sas(
pool.get_pool_queue(), pool.get_pool_queue(),
account_id=get_fuzz_storage(), StorageType.corpus,
read=True, read=True,
update=True, update=True,
process=True, process=True,

View File

@ -13,6 +13,7 @@ from onefuzztypes.requests import ContainerCreate, ContainerDelete, ContainerGet
from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase
from ..onefuzzlib.azure.containers import ( from ..onefuzzlib.azure.containers import (
StorageType,
create_container, create_container,
delete_container, delete_container,
get_container_metadata, get_container_metadata,
@ -30,7 +31,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error): if isinstance(request, Error):
return not_ok(request, context="container get") return not_ok(request, context="container get")
if request is not None: if request is not None:
metadata = get_container_metadata(request.name) metadata = get_container_metadata(request.name, StorageType.corpus)
if metadata is None: if metadata is None:
return not_ok( return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]), Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
@ -41,6 +42,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
name=request.name, name=request.name,
sas_url=get_container_sas_url( sas_url=get_container_sas_url(
request.name, request.name,
StorageType.corpus,
read=True, read=True,
write=True, write=True,
create=True, create=True,
@ -51,7 +53,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
) )
return ok(info) return ok(info)
containers = get_containers() containers = get_containers(StorageType.corpus)
container_info = [] container_info = []
for name in containers: for name in containers:
@ -66,7 +68,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(request, context="container create") return not_ok(request, context="container create")
logging.info("container - creating %s", request.name) logging.info("container - creating %s", request.name)
sas = create_container(request.name, metadata=request.metadata) sas = create_container(request.name, StorageType.corpus, metadata=request.metadata)
if sas: if sas:
return ok( return ok(
ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata) ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata)
@ -83,7 +85,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse:
return not_ok(request, context="container delete") return not_ok(request, context="container delete")
logging.info("container - deleting %s", request.name) logging.info("container - deleting %s", request.name)
return ok(BoolResult(result=delete_container(request.name))) return ok(BoolResult(result=delete_container(request.name, StorageType.corpus)))
def main(req: func.HttpRequest) -> func.HttpResponse: def main(req: func.HttpRequest) -> func.HttpResponse:

View File

@ -8,6 +8,7 @@ from onefuzztypes.enums import ErrorCode
from onefuzztypes.models import Error, FileEntry from onefuzztypes.models import Error, FileEntry
from ..onefuzzlib.azure.containers import ( from ..onefuzzlib.azure.containers import (
StorageType,
blob_exists, blob_exists,
container_exists, container_exists,
get_file_sas_url, get_file_sas_url,
@ -20,13 +21,13 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
if isinstance(request, Error): if isinstance(request, Error):
return not_ok(request, context="download") return not_ok(request, context="download")
if not container_exists(request.container): if not container_exists(request.container, StorageType.corpus):
return not_ok( return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]), Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]),
context=request.container, context=request.container,
) )
if not blob_exists(request.container, request.filename): if not blob_exists(request.container, request.filename, StorageType.corpus):
return not_ok( return not_ok(
Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]), Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]),
context=request.filename, context=request.filename,
@ -34,7 +35,12 @@ def get(req: func.HttpRequest) -> func.HttpResponse:
return redirect( return redirect(
get_file_sas_url( get_file_sas_url(
request.container, request.filename, read=True, days=0, minutes=5 request.container,
request.filename,
StorageType.corpus,
read=True,
days=0,
minutes=5,
) )
) )

View File

@ -6,37 +6,61 @@
import datetime import datetime
import os import os
import urllib.parse import urllib.parse
from typing import Dict, Optional, Union, cast from enum import Enum
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions from azure.storage.blob import BlobPermissions, ContainerPermissions
from memoization import cached from memoization import cached
from .creds import get_blob_service from .creds import get_blob_service, get_func_storage, get_fuzz_storage
class StorageType(Enum):
corpus = "corpus"
config = "config"
def get_account_id_by_type(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
account_id = get_fuzz_storage()
elif storage_type == StorageType.config:
account_id = get_func_storage()
else:
raise NotImplementedError
return account_id
@cached(ttl=5) @cached(ttl=5)
def container_exists(name: str, account_id: Optional[str] = None) -> bool: def get_blob_service_by_type(storage_type: StorageType) -> Any:
account_id = get_account_id_by_type(storage_type)
return get_blob_service(account_id)
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
try: try:
get_blob_service(account_id).get_container_properties(name) get_blob_service_by_type(storage_type).get_container_properties(name)
return True return True
except AzureHttpError: except AzureHttpError:
return False return False
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]: def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
return { return {
x.name: x.metadata x.name: x.metadata
for x in get_blob_service(account_id).list_containers(include_metadata=True) for x in get_blob_service_by_type(storage_type).list_containers(
include_metadata=True
)
if not x.name.startswith("$") if not x.name.startswith("$")
} }
def get_container_metadata( def get_container_metadata(
name: str, account_id: Optional[str] = None name: str, storage_type: StorageType
) -> Optional[Dict[str, str]]: ) -> Optional[Dict[str, str]]:
try: try:
result = get_blob_service(account_id).get_container_metadata(name) result = get_blob_service_by_type(storage_type).get_container_metadata(name)
return cast(Dict[str, str], result) return cast(Dict[str, str], result)
except AzureHttpError: except AzureHttpError:
pass pass
@ -44,22 +68,29 @@ def get_container_metadata(
def create_container( def create_container(
name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
) -> Optional[str]: ) -> Optional[str]:
try: try:
get_blob_service(account_id).create_container(name, metadata=metadata) get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
except AzureHttpError: except AzureHttpError:
# azure storage already logs errors # azure storage already logs errors
return None return None
return get_container_sas_url( return get_container_sas_url(
name, read=True, add=True, create=True, write=True, delete=True, list=True name,
storage_type,
read=True,
add=True,
create=True,
write=True,
delete=True,
list=True,
) )
def delete_container(name: str, account_id: Optional[str] = None) -> bool: def delete_container(name: str, storage_type: StorageType) -> bool:
try: try:
return bool(get_blob_service(account_id).delete_container(name)) return bool(get_blob_service_by_type(storage_type).delete_container(name))
except AzureHttpError: except AzureHttpError:
# azure storage already logs errors # azure storage already logs errors
return False return False
@ -67,7 +98,8 @@ def delete_container(name: str, account_id: Optional[str] = None) -> bool:
def get_container_sas_url( def get_container_sas_url(
container: str, container: str,
account_id: Optional[str] = None, storage_type: StorageType,
*,
read: bool = False, read: bool = False,
add: bool = False, add: bool = False,
create: bool = False, create: bool = False,
@ -75,7 +107,7 @@ def get_container_sas_url(
delete: bool = False, delete: bool = False,
list: bool = False, list: bool = False,
) -> str: ) -> str:
service = get_blob_service(account_id) service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
permission = ContainerPermissions(read, add, create, write, delete, list) permission = ContainerPermissions(read, add, create, write, delete, list)
@ -91,7 +123,8 @@ def get_container_sas_url(
def get_file_sas_url( def get_file_sas_url(
container: str, container: str,
name: str, name: str,
account_id: Optional[str] = None, storage_type: StorageType,
*,
read: bool = False, read: bool = False,
add: bool = False, add: bool = False,
create: bool = False, create: bool = False,
@ -102,7 +135,7 @@ def get_file_sas_url(
hours: int = 0, hours: int = 0,
minutes: int = 0, minutes: int = 0,
) -> str: ) -> str:
service = get_blob_service(account_id) service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta( expiry = datetime.datetime.utcnow() + datetime.timedelta(
days=days, hours=hours, minutes=minutes days=days, hours=hours, minutes=minutes
) )
@ -117,9 +150,9 @@ def get_file_sas_url(
def save_blob( def save_blob(
container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None container: str, name: str, data: Union[str, bytes], storage_type: StorageType
) -> None: ) -> None:
service = get_blob_service(account_id) service = get_blob_service_by_type(storage_type)
service.create_container(container) service.create_container(container)
if isinstance(data, str): if isinstance(data, str):
service.create_blob_from_text(container, name, data) service.create_blob_from_text(container, name, data)
@ -127,10 +160,8 @@ def save_blob(
service.create_blob_from_bytes(container, name, data) service.create_blob_from_bytes(container, name, data)
def get_blob( def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
container: str, name: str, account_id: Optional[str] = None service = get_blob_service_by_type(storage_type)
) -> Optional[bytes]:
service = get_blob_service(account_id)
try: try:
blob = service.get_blob_to_bytes(container, name).content blob = service.get_blob_to_bytes(container, name).content
return cast(bytes, blob) return cast(bytes, blob)
@ -138,8 +169,8 @@ def get_blob(
return None return None
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool: def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service(account_id) service = get_blob_service_by_type(storage_type)
try: try:
service.get_blob_properties(container, name) service.get_blob_properties(container, name)
return True return True
@ -147,8 +178,8 @@ def blob_exists(container: str, name: str, account_id: Optional[str] = None) ->
return False return False
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool: def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service(account_id) service = get_blob_service_by_type(storage_type)
try: try:
service.delete_blob(container, name) service.delete_blob(container, name)
return True return True

View File

@ -87,12 +87,12 @@ def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"] return os.environ["APPINSIGHTS_APPID"]
@cached # @cached
def get_fuzz_storage() -> str: def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"] return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached # @cached
def get_func_storage() -> str: def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"] return os.environ["ONEFUZZ_FUNC_STORAGE"]
@ -109,9 +109,9 @@ def get_instance_url() -> str:
@cached @cached
def get_instance_id() -> UUID: def get_instance_id() -> UUID:
from .containers import get_blob from .containers import StorageType, get_blob
blob = get_blob("base-config", "instance_id", account_id=get_func_storage()) blob = get_blob("base-config", "instance_id", StorageType.config)
if blob is None: if blob is None:
raise Exception("missing instance_id") raise Exception("missing instance_id")
return UUID(blob.decode()) return UUID(blob.decode())

View File

@ -19,6 +19,7 @@ from azure.storage.queue import (
from memoization import cached from memoization import cached
from pydantic import BaseModel from pydantic import BaseModel
from .containers import StorageType, get_account_id_by_type
from .creds import get_storage_account_name_key from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID] QueueNameType = Union[str, UUID]
@ -27,7 +28,8 @@ DEFAULT_TTL = -1
@cached(ttl=60) @cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient: def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id) logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id) name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name account_url = "https://%s.queue.core.windows.net" % name
@ -41,13 +43,14 @@ def get_queue_client(account_id: str) -> QueueServiceClient:
@cached(ttl=60) @cached(ttl=60)
def get_queue_sas( def get_queue_sas(
queue: QueueNameType, queue: QueueNameType,
storage_type: StorageType,
*, *,
account_id: str,
read: bool = False, read: bool = False,
add: bool = False, add: bool = False,
update: bool = False, update: bool = False,
process: bool = False, process: bool = False,
) -> str: ) -> str:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id) logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id) name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
@ -67,31 +70,33 @@ def get_queue_sas(
@cached(ttl=60) @cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None: def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(account_id) client = get_queue_client(storage_type)
try: try:
client.create_queue(str(name)) client.create_queue(str(name))
except ResourceExistsError: except ResourceExistsError:
pass pass
def delete_queue(name: QueueNameType, *, account_id: str) -> None: def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(account_id) client = get_queue_client(storage_type)
queues = client.list_queues() queues = client.list_queues()
if str(name) in [x["name"] for x in queues]: if str(name) in [x["name"] for x in queues]:
client.delete_queue(name) client.delete_queue(name)
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]: def get_queue(
client = get_queue_client(account_id) name: QueueNameType, storage_type: StorageType
) -> Optional[QueueServiceClient]:
client = get_queue_client(storage_type)
try: try:
return client.get_queue_client(str(name)) return client.get_queue_client(str(name))
except ResourceNotFoundError: except ResourceNotFoundError:
return None return None
def clear_queue(name: QueueNameType, *, account_id: str) -> None: def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
queue = get_queue(name, account_id=account_id) queue = get_queue(name, storage_type)
if queue: if queue:
try: try:
queue.clear_messages() queue.clear_messages()
@ -102,12 +107,12 @@ def clear_queue(name: QueueNameType, *, account_id: str) -> None:
def send_message( def send_message(
name: QueueNameType, name: QueueNameType,
message: bytes, message: bytes,
storage_type: StorageType,
*, *,
account_id: str,
visibility_timeout: Optional[int] = None, visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL, time_to_live: int = DEFAULT_TTL,
) -> None: ) -> None:
queue = get_queue(name, account_id=account_id) queue = get_queue(name, storage_type)
if queue: if queue:
try: try:
queue.send_message( queue.send_message(
@ -119,9 +124,8 @@ def send_message(
pass pass
def remove_first_message(name: QueueNameType, *, account_id: str) -> bool: def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
create_queue(name, account_id=account_id) queue = get_queue(name, storage_type)
queue = get_queue(name, account_id=account_id)
if queue: if queue:
try: try:
for message in queue.receive_messages(): for message in queue.receive_messages():
@ -143,8 +147,8 @@ MAX_PEEK_SIZE = 32
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient # https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue( def peek_queue(
name: QueueNameType, name: QueueNameType,
storage_type: StorageType,
*, *,
account_id: str,
object_type: Type[A], object_type: Type[A],
max_messages: int = MAX_PEEK_SIZE, max_messages: int = MAX_PEEK_SIZE,
) -> List[A]: ) -> List[A]:
@ -154,7 +158,7 @@ def peek_queue(
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE: if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
raise ValueError("invalid max messages: %s" % max_messages) raise ValueError("invalid max messages: %s" % max_messages)
queue = get_queue(name, account_id=account_id) queue = get_queue(name, storage_type)
if not queue: if not queue:
return result return result
@ -168,12 +172,12 @@ def peek_queue(
def queue_object( def queue_object(
name: QueueNameType, name: QueueNameType,
message: BaseModel, message: BaseModel,
storage_type: StorageType,
*, *,
account_id: str,
visibility_timeout: Optional[int] = None, visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL, time_to_live: int = DEFAULT_TTL,
) -> bool: ) -> bool:
queue = get_queue(name, account_id=account_id) queue = get_queue(name, storage_type)
if not queue: if not queue:
raise Exception("unable to queue object, no such queue: %s" % queue) raise Exception("unable to queue object, no such queue: %s" % queue)

View File

@ -11,8 +11,13 @@ from onefuzztypes.enums import OS, AgentMode
from onefuzztypes.models import AgentConfig, ReproConfig from onefuzztypes.models import AgentConfig, ReproConfig
from onefuzztypes.primitives import Extension, Region from onefuzztypes.primitives import Extension, Region
from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob from .azure.containers import (
from .azure.creds import get_func_storage, get_instance_id, get_instance_url StorageType,
get_container_sas_url,
get_file_sas_url,
save_blob,
)
from .azure.creds import get_instance_id, get_instance_url
from .azure.monitor import get_monitor_settings from .azure.monitor import get_monitor_settings
from .azure.queue import get_queue_sas from .azure.queue import get_queue_sas
from .reports import get_report from .reports import get_report
@ -96,7 +101,7 @@ def build_pool_config(pool_name: str) -> str:
instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"), instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"),
heartbeat_queue=get_queue_sas( heartbeat_queue=get_queue_sas(
"node-heartbeat", "node-heartbeat",
account_id=get_func_storage(), StorageType.config,
add=True, add=True,
), ),
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
@ -107,13 +112,13 @@ def build_pool_config(pool_name: str) -> str:
"vm-scripts", "vm-scripts",
"%s/config.json" % pool_name, "%s/config.json" % pool_name,
config.json(), config.json(),
account_id=get_func_storage(), StorageType.config,
) )
return get_file_sas_url( return get_file_sas_url(
"vm-scripts", "vm-scripts",
"%s/config.json" % pool_name, "%s/config.json" % pool_name,
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
) )
@ -124,30 +129,26 @@ def update_managed_scripts() -> None:
% ( % (
get_container_sas_url( get_container_sas_url(
"instance-specific-setup", "instance-specific-setup",
StorageType.config,
read=True, read=True,
list=True, list=True,
account_id=get_func_storage(),
) )
), ),
"azcopy sync '%s' tools" "azcopy sync '%s' tools"
% ( % (get_container_sas_url("tools", StorageType.config, read=True, list=True)),
get_container_sas_url(
"tools", read=True, list=True, account_id=get_func_storage()
)
),
] ]
save_blob( save_blob(
"vm-scripts", "vm-scripts",
"managed.ps1", "managed.ps1",
"\r\n".join(commands) + "\r\n", "\r\n".join(commands) + "\r\n",
account_id=get_func_storage(), StorageType.config,
) )
save_blob( save_blob(
"vm-scripts", "vm-scripts",
"managed.sh", "managed.sh",
"\n".join(commands) + "\n", "\n".join(commands) + "\n",
account_id=get_func_storage(), StorageType.config,
) )
@ -164,25 +165,25 @@ def agent_config(
get_file_sas_url( get_file_sas_url(
"vm-scripts", "vm-scripts",
"managed.ps1", "managed.ps1",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"win64/azcopy.exe", "win64/azcopy.exe",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"win64/setup.ps1", "win64/setup.ps1",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"win64/onefuzz.ps1", "win64/onefuzz.ps1",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
] ]
@ -206,19 +207,19 @@ def agent_config(
get_file_sas_url( get_file_sas_url(
"vm-scripts", "vm-scripts",
"managed.sh", "managed.sh",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"linux/azcopy", "linux/azcopy",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"linux/setup.sh", "linux/setup.sh",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
] ]
@ -263,13 +264,22 @@ def repro_extensions(
if setup_container: if setup_container:
commands += [ commands += [
"azcopy sync '%s' ./setup" "azcopy sync '%s' ./setup"
% (get_container_sas_url(setup_container, read=True, list=True)), % (
get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)
),
] ]
urls = [ urls = [
get_file_sas_url(repro_config.container, repro_config.path, read=True),
get_file_sas_url( get_file_sas_url(
report.input_blob.container, report.input_blob.name, read=True repro_config.container, repro_config.path, StorageType.corpus, read=True
),
get_file_sas_url(
report.input_blob.container,
report.input_blob.name,
StorageType.corpus,
read=True,
), ),
] ]
@ -288,7 +298,7 @@ def repro_extensions(
"task-configs", "task-configs",
"%s/%s" % (repro_id, script_name), "%s/%s" % (repro_id, script_name),
task_script, task_script,
account_id=get_func_storage(), StorageType.config,
) )
for repro_file in repro_files: for repro_file in repro_files:
@ -296,13 +306,13 @@ def repro_extensions(
get_file_sas_url( get_file_sas_url(
"repro-scripts", "repro-scripts",
repro_file, repro_file,
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"task-configs", "task-configs",
"%s/%s" % (repro_id, script_name), "%s/%s" % (repro_id, script_name),
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
] ]
@ -318,13 +328,13 @@ def proxy_manager_extensions(region: Region) -> List[Extension]:
get_file_sas_url( get_file_sas_url(
"proxy-configs", "proxy-configs",
"%s/config.json" % region, "%s/config.json" % region,
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
get_file_sas_url( get_file_sas_url(
"tools", "tools",
"linux/onefuzz-proxy-manager", "linux/onefuzz-proxy-manager",
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
] ]

View File

@ -21,11 +21,11 @@ from onefuzztypes.models import (
from onefuzztypes.primitives import Container, Event from onefuzztypes.primitives import Container, Event
from ..azure.containers import ( from ..azure.containers import (
StorageType,
container_exists, container_exists,
get_container_metadata, get_container_metadata,
get_file_sas_url, get_file_sas_url,
) )
from ..azure.creds import get_fuzz_storage
from ..azure.queue import send_message from ..azure.queue import send_message
from ..dashboard import add_event from ..dashboard import add_event
from ..orm import ORMMixin from ..orm import ORMMixin
@ -72,7 +72,7 @@ class Notification(models.Notification, ORMMixin):
def create( def create(
cls, container: Container, config: NotificationTemplate cls, container: Container, config: NotificationTemplate
) -> Result["Notification"]: ) -> Result["Notification"]:
if not container_exists(container): if not container_exists(container, StorageType.corpus):
return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]) return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"])
existing = cls.get_existing(container, config) existing = cls.get_existing(container, config)
@ -106,7 +106,7 @@ def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]:
@cached(ttl=60) @cached(ttl=60)
def container_metadata(container: Container) -> Optional[Dict[str, str]]: def container_metadata(container: Container) -> Optional[Dict[str, str]]:
return get_container_metadata(container) return get_container_metadata(container, StorageType.corpus)
def new_files(container: Container, filename: str) -> None: def new_files(container: Container, filename: str) -> None:
@ -149,9 +149,9 @@ def new_files(container: Container, filename: str) -> None:
for (task, containers) in get_queue_tasks(): for (task, containers) in get_queue_tasks():
if container in containers: if container in containers:
logging.info("queuing input %s %s %s", container, filename, task.task_id) logging.info("queuing input %s %s %s", container, filename, task.task_id)
url = get_file_sas_url(container, filename, read=True, delete=True) url = get_file_sas_url(
send_message( container, filename, StorageType.corpus, read=True, delete=True
task.task_id, bytes(url, "utf-8"), account_id=get_fuzz_storage()
) )
send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus)
add_event("new_file", results) add_event("new_file", results)

View File

@ -35,7 +35,7 @@ from pydantic import BaseModel, Field
from .__version__ import __version__ from .__version__ import __version__
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.creds import get_func_storage, get_fuzz_storage from .azure.containers import StorageType
from .azure.image import get_os from .azure.image import get_os
from .azure.network import Network from .azure.network import Network
from .azure.queue import ( from .azure.queue import (
@ -442,7 +442,7 @@ class Pool(BASE_POOL, ORMMixin):
return return
worksets = peek_queue( worksets = peek_queue(
self.get_pool_queue(), account_id=get_fuzz_storage(), object_type=WorkSet self.get_pool_queue(), StorageType.corpus, object_type=WorkSet
) )
for workset in worksets: for workset in worksets:
@ -460,7 +460,7 @@ class Pool(BASE_POOL, ORMMixin):
return "pool-%s" % self.pool_id.hex return "pool-%s" % self.pool_id.hex
def init(self) -> None: def init(self) -> None:
create_queue(self.get_pool_queue(), account_id=get_fuzz_storage()) create_queue(self.get_pool_queue(), StorageType.corpus)
self.state = PoolState.running self.state = PoolState.running
self.save() self.save()
@ -470,7 +470,9 @@ class Pool(BASE_POOL, ORMMixin):
return False return False
return queue_object( return queue_object(
self.get_pool_queue(), work_set, account_id=get_fuzz_storage() self.get_pool_queue(),
work_set,
StorageType.corpus,
) )
@classmethod @classmethod
@ -531,7 +533,7 @@ class Pool(BASE_POOL, ORMMixin):
scalesets = Scaleset.search_by_pool(self.name) scalesets = Scaleset.search_by_pool(self.name)
nodes = Node.search(query={"pool_name": [self.name]}) nodes = Node.search(query={"pool_name": [self.name]})
if not scalesets and not nodes: if not scalesets and not nodes:
delete_queue(self.get_pool_queue(), account_id=get_fuzz_storage()) delete_queue(self.get_pool_queue(), StorageType.corpus)
logging.info("pool stopped, deleting: %s", self.name) logging.info("pool stopped, deleting: %s", self.name)
self.state = PoolState.halt self.state = PoolState.halt
self.delete() self.delete()
@ -1053,16 +1055,16 @@ class ScalesetShrinkQueue:
return "to-shrink-%s" % self.scaleset_id.hex return "to-shrink-%s" % self.scaleset_id.hex
def clear(self) -> None: def clear(self) -> None:
clear_queue(self.queue_name(), account_id=get_func_storage()) clear_queue(self.queue_name(), StorageType.config)
def create(self) -> None: def create(self) -> None:
create_queue(self.queue_name(), account_id=get_func_storage()) create_queue(self.queue_name(), StorageType.config)
def delete(self) -> None: def delete(self) -> None:
delete_queue(self.queue_name(), account_id=get_func_storage()) delete_queue(self.queue_name(), StorageType.config)
def add_entry(self) -> None: def add_entry(self) -> None:
queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage()) queue_object(self.queue_name(), ShrinkEntry(), StorageType.config)
def should_shrink(self) -> bool: def should_shrink(self) -> bool:
return remove_first_message(self.queue_name(), account_id=get_func_storage()) return remove_first_message(self.queue_name(), StorageType.config)

View File

@ -21,8 +21,7 @@ from pydantic import Field
from .__version__ import __version__ from .__version__ import __version__
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.containers import get_file_sas_url, save_blob from .azure.containers import StorageType, get_file_sas_url, save_blob
from .azure.creds import get_func_storage
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.queue import get_queue_sas from .azure.queue import get_queue_sas
from .azure.vm import VM from .azure.vm import VM
@ -191,12 +190,12 @@ class Proxy(ORMMixin):
url=get_file_sas_url( url=get_file_sas_url(
"proxy-configs", "proxy-configs",
"%s/config.json" % self.region, "%s/config.json" % self.region,
account_id=get_func_storage(), StorageType.config,
read=True, read=True,
), ),
notification=get_queue_sas( notification=get_queue_sas(
"proxy", "proxy",
account_id=get_func_storage(), StorageType.config,
add=True, add=True,
), ),
forwards=forwards, forwards=forwards,
@ -207,7 +206,7 @@ class Proxy(ORMMixin):
"proxy-configs", "proxy-configs",
"%s/config.json" % self.region, "%s/config.json" % self.region,
proxy_config.json(), proxy_config.json(),
account_id=get_func_storage(), StorageType.config,
) )
@classmethod @classmethod

View File

@ -10,7 +10,7 @@ from typing import Optional, Union
from onefuzztypes.models import Report from onefuzztypes.models import Report
from pydantic import ValidationError from pydantic import ValidationError
from .azure.containers import get_blob from .azure.containers import StorageType, get_blob
def parse_report( def parse_report(
@ -50,7 +50,7 @@ def get_report(container: str, filename: str) -> Optional[Report]:
logging.error("get_report invalid extension: %s", metadata) logging.error("get_report invalid extension: %s", metadata)
return None return None
blob = get_blob(container, filename) blob = get_blob(container, filename, StorageType.corpus)
if blob is None: if blob is None:
logging.error("get_report invalid blob: %s", metadata) logging.error("get_report invalid blob: %s", metadata)
return None return None

View File

@ -14,8 +14,8 @@ from onefuzztypes.models import Repro as BASE_REPRO
from onefuzztypes.models import ReproConfig, TaskVm from onefuzztypes.models import ReproConfig, TaskVm
from .azure.auth import build_auth from .azure.auth import build_auth
from .azure.containers import save_blob from .azure.containers import StorageType, save_blob
from .azure.creds import get_base_region, get_func_storage from .azure.creds import get_base_region
from .azure.ip import get_public_ip from .azure.ip import get_public_ip
from .azure.vm import VM from .azure.vm import VM
from .extension import repro_extensions from .extension import repro_extensions
@ -205,7 +205,7 @@ class Repro(BASE_REPRO, ORMMixin):
"repro-scripts", "repro-scripts",
"%s/%s" % (self.vm_id, filename), "%s/%s" % (self.vm_id, filename),
files[filename], files[filename],
account_id=get_func_storage(), StorageType.config,
) )
logging.info("saved repro script") logging.info("saved repro script")

View File

@ -11,13 +11,13 @@ from uuid import UUID
from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature
from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig
from ..azure.containers import blob_exists, container_exists, get_container_sas_url from ..azure.containers import (
from ..azure.creds import ( StorageType,
get_func_storage, blob_exists,
get_fuzz_storage, container_exists,
get_instance_id, get_container_sas_url,
get_instance_url,
) )
from ..azure.creds import get_instance_id, get_instance_url
from ..azure.queue import get_queue_sas from ..azure.queue import get_queue_sas
from .defs import TASK_DEFINITIONS from .defs import TASK_DEFINITIONS
@ -68,7 +68,7 @@ def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
containers: Dict[ContainerType, List[str]] = {} containers: Dict[ContainerType, List[str]] = {}
for container in config.containers: for container in config.containers:
if container.name not in checked: if container.name not in checked:
if not container_exists(container.name): if not container_exists(container.name, StorageType.corpus):
raise TaskConfigError("missing container: %s" % container.name) raise TaskConfigError("missing container: %s" % container.name)
checked.add(container.name) checked.add(container.name)
@ -137,7 +137,7 @@ def check_config(config: TaskConfig) -> None:
if TaskFeature.target_exe in definition.features: if TaskFeature.target_exe in definition.features:
container = [x for x in config.containers if x.type == ContainerType.setup][0] container = [x for x in config.containers if x.type == ContainerType.setup][0]
if not blob_exists(container.name, config.task.target_exe): if not blob_exists(container.name, config.task.target_exe, StorageType.corpus):
err = "target_exe `%s` does not exist in the setup container `%s`" % ( err = "target_exe `%s` does not exist in the setup container `%s`" % (
config.task.target_exe, config.task.target_exe,
container.name, container.name,
@ -153,7 +153,7 @@ def check_config(config: TaskConfig) -> None:
for tool_path in tools_paths: for tool_path in tools_paths:
if config.task.generator_exe.startswith(tool_path): if config.task.generator_exe.startswith(tool_path):
generator = config.task.generator_exe.replace(tool_path, "") generator = config.task.generator_exe.replace(tool_path, "")
if not blob_exists(container.name, generator): if not blob_exists(container.name, generator, StorageType.corpus):
err = ( err = (
"generator_exe `%s` does not exist in the tools container `%s`" "generator_exe `%s` does not exist in the tools container `%s`"
% ( % (
@ -188,7 +188,7 @@ def build_task_config(
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
heartbeat_queue=get_queue_sas( heartbeat_queue=get_queue_sas(
"task-heartbeat", "task-heartbeat",
account_id=get_func_storage(), StorageType.config,
add=True, add=True,
), ),
back_channel_address="https://%s/api/back_channel" % (get_instance_url()), back_channel_address="https://%s/api/back_channel" % (get_instance_url()),
@ -198,11 +198,11 @@ def build_task_config(
if definition.monitor_queue: if definition.monitor_queue:
config.input_queue = get_queue_sas( config.input_queue = get_queue_sas(
task_id, task_id,
StorageType.corpus,
add=True, add=True,
read=True, read=True,
update=True, update=True,
process=True, process=True,
account_id=get_fuzz_storage(),
) )
for container_def in definition.containers: for container_def in definition.containers:
@ -219,6 +219,7 @@ def build_task_config(
"path": "_".join(["task", container_def.type.name, str(i)]), "path": "_".join(["task", container_def.type.name, str(i)]),
"url": get_container_sas_url( "url": get_container_sas_url(
container.name, container.name,
StorageType.corpus,
read=ContainerPermission.Read in container_def.permissions, read=ContainerPermission.Read in container_def.permissions,
write=ContainerPermission.Write in container_def.permissions, write=ContainerPermission.Write in container_def.permissions,
add=ContainerPermission.Add in container_def.permissions, add=ContainerPermission.Add in container_def.permissions,

View File

@ -18,7 +18,7 @@ from onefuzztypes.webhooks import (
WebhookEventTaskStopped, WebhookEventTaskStopped,
) )
from ..azure.creds import get_fuzz_storage from ..azure.containers import StorageType
from ..azure.image import get_os from ..azure.image import get_os
from ..azure.queue import create_queue, delete_queue from ..azure.queue import create_queue, delete_queue
from ..orm import MappingIntStrAny, ORMMixin, QueryFilter from ..orm import MappingIntStrAny, ORMMixin, QueryFilter
@ -123,7 +123,7 @@ class Task(BASE_TASK, ORMMixin):
} }
def init(self) -> None: def init(self) -> None:
create_queue(self.task_id, account_id=get_fuzz_storage()) create_queue(self.task_id, StorageType.corpus)
self.state = TaskState.waiting self.state = TaskState.waiting
self.save() self.save()
@ -132,7 +132,7 @@ class Task(BASE_TASK, ORMMixin):
logging.info("stopping task: %s:%s", self.job_id, self.task_id) logging.info("stopping task: %s:%s", self.job_id, self.task_id)
ProxyForward.remove_forward(self.task_id) ProxyForward.remove_forward(self.task_id)
delete_queue(str(self.task_id), account_id=get_fuzz_storage()) delete_queue(str(self.task_id), StorageType.corpus)
Node.stop_task(self.task_id) Node.stop_task(self.task_id)
self.state = TaskState.stopped self.state = TaskState.stopped
self.save() self.save()

View File

@ -10,8 +10,12 @@ from uuid import UUID
from onefuzztypes.enums import OS, PoolState, TaskState from onefuzztypes.enums import OS, PoolState, TaskState
from onefuzztypes.models import WorkSet, WorkUnit from onefuzztypes.models import WorkSet, WorkUnit
from ..azure.containers import blob_exists, get_container_sas_url, save_blob from ..azure.containers import (
from ..azure.creds import get_func_storage StorageType,
blob_exists,
get_container_sas_url,
save_blob,
)
from ..pools import Pool from ..pools import Pool
from .config import build_task_config, get_setup_container from .config import build_task_config, get_setup_container
from .main import Task from .main import Task
@ -60,20 +64,26 @@ def schedule_tasks() -> None:
agent_config = build_task_config(task.job_id, task.task_id, task.config) agent_config = build_task_config(task.job_id, task.task_id, task.config)
setup_container = get_setup_container(task.config) setup_container = get_setup_container(task.config)
setup_url = get_container_sas_url(setup_container, read=True, list=True) setup_url = get_container_sas_url(
setup_container, StorageType.corpus, read=True, list=True
)
setup_script = None setup_script = None
if task.os == OS.windows and blob_exists(setup_container, "setup.ps1"): if task.os == OS.windows and blob_exists(
setup_container, "setup.ps1", StorageType.corpus
):
setup_script = "setup.ps1" setup_script = "setup.ps1"
if task.os == OS.linux and blob_exists(setup_container, "setup.sh"): if task.os == OS.linux and blob_exists(
setup_container, "setup.sh", StorageType.corpus
):
setup_script = "setup.sh" setup_script = "setup.sh"
save_blob( save_blob(
"task-configs", "task-configs",
"%s/config.json" % task.task_id, "%s/config.json" % task.task_id,
agent_config.json(exclude_none=True), agent_config.json(exclude_none=True),
account_id=get_func_storage(), StorageType.config,
) )
reboot = False reboot = False
count = 1 count = 1

View File

@ -10,7 +10,7 @@ from msrestazure.azure_exceptions import CloudError
from onefuzztypes.enums import UpdateType from onefuzztypes.enums import UpdateType
from pydantic import BaseModel from pydantic import BaseModel
from .azure.creds import get_func_storage from .azure.containers import StorageType
from .azure.queue import queue_object from .azure.queue import queue_object
@ -46,7 +46,7 @@ def queue_update(
if not queue_object( if not queue_object(
"update-queue", "update-queue",
update, update,
account_id=get_func_storage(), StorageType.config,
visibility_timeout=visibility_timeout, visibility_timeout=visibility_timeout,
): ):
logging.error("unable to queue update") logging.error("unable to queue update")

View File

@ -27,7 +27,7 @@ from onefuzztypes.webhooks import WebhookMessageLog as BASE_WEBHOOK_MESSAGE_LOG
from pydantic import BaseModel from pydantic import BaseModel
from .__version__ import __version__ from .__version__ import __version__
from .azure.creds import get_func_storage from .azure.containers import StorageType
from .azure.queue import queue_object from .azure.queue import queue_object
from .orm import ORMMixin from .orm import ORMMixin
@ -135,8 +135,8 @@ class WebhookMessageLog(BASE_WEBHOOK_MESSAGE_LOG, ORMMixin):
queue_object( queue_object(
"webhooks", "webhooks",
obj, obj,
StorageType.config,
visibility_timeout=visibility_timeout, visibility_timeout=visibility_timeout,
account_id=get_func_storage(),
) )

View File

@ -12,9 +12,9 @@ from onefuzztypes.models import AgentConfig, Error
from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop
from onefuzztypes.responses import BoolResult from onefuzztypes.responses import BoolResult
from ..onefuzzlib.azure.containers import StorageType
from ..onefuzzlib.azure.creds import ( from ..onefuzzlib.azure.creds import (
get_base_region, get_base_region,
get_func_storage,
get_instance_id, get_instance_id,
get_instance_url, get_instance_url,
get_regions, get_regions,
@ -33,7 +33,7 @@ def set_config(pool: Pool) -> Pool:
telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"),
heartbeat_queue=get_queue_sas( heartbeat_queue=get_queue_sas(
"node-heartbeat", "node-heartbeat",
account_id=get_func_storage(), StorageType.config,
add=True, add=True,
), ),
instance_id=get_instance_id(), instance_id=get_instance_id(),