Use Storage Account types, rather than account_id (#320)

We need to move to supporting data sharding.

One of the steps towards that is stop passing around `account_id`, rather we need to specify the type of storage we need.
This commit is contained in:
bmc-msft
2020-11-18 09:06:14 -05:00
committed by GitHub
parent 52eca33237
commit e47e89609a
18 changed files with 205 additions and 139 deletions

View File

@ -6,37 +6,61 @@
import datetime
import os
import urllib.parse
from typing import Dict, Optional, Union, cast
from enum import Enum
from typing import Any, Dict, Optional, Union, cast
from azure.common import AzureHttpError, AzureMissingResourceHttpError
from azure.storage.blob import BlobPermissions, ContainerPermissions
from memoization import cached
from .creds import get_blob_service
from .creds import get_blob_service, get_func_storage, get_fuzz_storage
class StorageType(Enum):
corpus = "corpus"
config = "config"
def get_account_id_by_type(storage_type: StorageType) -> str:
if storage_type == StorageType.corpus:
account_id = get_fuzz_storage()
elif storage_type == StorageType.config:
account_id = get_func_storage()
else:
raise NotImplementedError
return account_id
@cached(ttl=5)
def container_exists(name: str, account_id: Optional[str] = None) -> bool:
def get_blob_service_by_type(storage_type: StorageType) -> Any:
account_id = get_account_id_by_type(storage_type)
return get_blob_service(account_id)
@cached(ttl=5)
def container_exists(name: str, storage_type: StorageType) -> bool:
try:
get_blob_service(account_id).get_container_properties(name)
get_blob_service_by_type(storage_type).get_container_properties(name)
return True
except AzureHttpError:
return False
def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]:
def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]:
return {
x.name: x.metadata
for x in get_blob_service(account_id).list_containers(include_metadata=True)
for x in get_blob_service_by_type(storage_type).list_containers(
include_metadata=True
)
if not x.name.startswith("$")
}
def get_container_metadata(
name: str, account_id: Optional[str] = None
name: str, storage_type: StorageType
) -> Optional[Dict[str, str]]:
try:
result = get_blob_service(account_id).get_container_metadata(name)
result = get_blob_service_by_type(storage_type).get_container_metadata(name)
return cast(Dict[str, str], result)
except AzureHttpError:
pass
@ -44,22 +68,29 @@ def get_container_metadata(
def create_container(
name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None
name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]]
) -> Optional[str]:
try:
get_blob_service(account_id).create_container(name, metadata=metadata)
get_blob_service_by_type(storage_type).create_container(name, metadata=metadata)
except AzureHttpError:
# azure storage already logs errors
return None
return get_container_sas_url(
name, read=True, add=True, create=True, write=True, delete=True, list=True
name,
storage_type,
read=True,
add=True,
create=True,
write=True,
delete=True,
list=True,
)
def delete_container(name: str, account_id: Optional[str] = None) -> bool:
def delete_container(name: str, storage_type: StorageType) -> bool:
try:
return bool(get_blob_service(account_id).delete_container(name))
return bool(get_blob_service_by_type(storage_type).delete_container(name))
except AzureHttpError:
# azure storage already logs errors
return False
@ -67,7 +98,8 @@ def delete_container(name: str, account_id: Optional[str] = None) -> bool:
def get_container_sas_url(
container: str,
account_id: Optional[str] = None,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
@ -75,7 +107,7 @@ def get_container_sas_url(
delete: bool = False,
list: bool = False,
) -> str:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
permission = ContainerPermissions(read, add, create, write, delete, list)
@ -91,7 +123,8 @@ def get_container_sas_url(
def get_file_sas_url(
container: str,
name: str,
account_id: Optional[str] = None,
storage_type: StorageType,
*,
read: bool = False,
add: bool = False,
create: bool = False,
@ -102,7 +135,7 @@ def get_file_sas_url(
hours: int = 0,
minutes: int = 0,
) -> str:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
expiry = datetime.datetime.utcnow() + datetime.timedelta(
days=days, hours=hours, minutes=minutes
)
@ -117,9 +150,9 @@ def get_file_sas_url(
def save_blob(
container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None
container: str, name: str, data: Union[str, bytes], storage_type: StorageType
) -> None:
service = get_blob_service(account_id)
service = get_blob_service_by_type(storage_type)
service.create_container(container)
if isinstance(data, str):
service.create_blob_from_text(container, name, data)
@ -127,10 +160,8 @@ def save_blob(
service.create_blob_from_bytes(container, name, data)
def get_blob(
container: str, name: str, account_id: Optional[str] = None
) -> Optional[bytes]:
service = get_blob_service(account_id)
def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]:
service = get_blob_service_by_type(storage_type)
try:
blob = service.get_blob_to_bytes(container, name).content
return cast(bytes, blob)
@ -138,8 +169,8 @@ def get_blob(
return None
def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool:
service = get_blob_service(account_id)
def blob_exists(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
try:
service.get_blob_properties(container, name)
return True
@ -147,8 +178,8 @@ def blob_exists(container: str, name: str, account_id: Optional[str] = None) ->
return False
def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool:
service = get_blob_service(account_id)
def delete_blob(container: str, name: str, storage_type: StorageType) -> bool:
service = get_blob_service_by_type(storage_type)
try:
service.delete_blob(container, name)
return True

View File

@ -87,12 +87,12 @@ def get_insights_appid() -> str:
return os.environ["APPINSIGHTS_APPID"]
@cached
# @cached
def get_fuzz_storage() -> str:
return os.environ["ONEFUZZ_DATA_STORAGE"]
@cached
# @cached
def get_func_storage() -> str:
return os.environ["ONEFUZZ_FUNC_STORAGE"]
@ -109,9 +109,9 @@ def get_instance_url() -> str:
@cached
def get_instance_id() -> UUID:
from .containers import get_blob
from .containers import StorageType, get_blob
blob = get_blob("base-config", "instance_id", account_id=get_func_storage())
blob = get_blob("base-config", "instance_id", StorageType.config)
if blob is None:
raise Exception("missing instance_id")
return UUID(blob.decode())

View File

@ -19,6 +19,7 @@ from azure.storage.queue import (
from memoization import cached
from pydantic import BaseModel
from .containers import StorageType, get_account_id_by_type
from .creds import get_storage_account_name_key
QueueNameType = Union[str, UUID]
@ -27,7 +28,8 @@ DEFAULT_TTL = -1
@cached(ttl=60)
def get_queue_client(account_id: str) -> QueueServiceClient:
def get_queue_client(storage_type: StorageType) -> QueueServiceClient:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting blob container (account_id: %s)", account_id)
name, key = get_storage_account_name_key(account_id)
account_url = "https://%s.queue.core.windows.net" % name
@ -41,13 +43,14 @@ def get_queue_client(account_id: str) -> QueueServiceClient:
@cached(ttl=60)
def get_queue_sas(
queue: QueueNameType,
storage_type: StorageType,
*,
account_id: str,
read: bool = False,
add: bool = False,
update: bool = False,
process: bool = False,
) -> str:
account_id = get_account_id_by_type(storage_type)
logging.debug("getting queue sas %s (account_id: %s)", queue, account_id)
name, key = get_storage_account_name_key(account_id)
expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30)
@ -67,31 +70,33 @@ def get_queue_sas(
@cached(ttl=60)
def create_queue(name: QueueNameType, *, account_id: str) -> None:
client = get_queue_client(account_id)
def create_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
try:
client.create_queue(str(name))
except ResourceExistsError:
pass
def delete_queue(name: QueueNameType, *, account_id: str) -> None:
client = get_queue_client(account_id)
def delete_queue(name: QueueNameType, storage_type: StorageType) -> None:
client = get_queue_client(storage_type)
queues = client.list_queues()
if str(name) in [x["name"] for x in queues]:
client.delete_queue(name)
def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]:
client = get_queue_client(account_id)
def get_queue(
name: QueueNameType, storage_type: StorageType
) -> Optional[QueueServiceClient]:
client = get_queue_client(storage_type)
try:
return client.get_queue_client(str(name))
except ResourceNotFoundError:
return None
def clear_queue(name: QueueNameType, *, account_id: str) -> None:
queue = get_queue(name, account_id=account_id)
def clear_queue(name: QueueNameType, storage_type: StorageType) -> None:
queue = get_queue(name, storage_type)
if queue:
try:
queue.clear_messages()
@ -102,12 +107,12 @@ def clear_queue(name: QueueNameType, *, account_id: str) -> None:
def send_message(
name: QueueNameType,
message: bytes,
storage_type: StorageType,
*,
account_id: str,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> None:
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if queue:
try:
queue.send_message(
@ -119,9 +124,8 @@ def send_message(
pass
def remove_first_message(name: QueueNameType, *, account_id: str) -> bool:
create_queue(name, account_id=account_id)
queue = get_queue(name, account_id=account_id)
def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool:
queue = get_queue(name, storage_type)
if queue:
try:
for message in queue.receive_messages():
@ -143,8 +147,8 @@ MAX_PEEK_SIZE = 32
# https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient
def peek_queue(
name: QueueNameType,
storage_type: StorageType,
*,
account_id: str,
object_type: Type[A],
max_messages: int = MAX_PEEK_SIZE,
) -> List[A]:
@ -154,7 +158,7 @@ def peek_queue(
if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE:
raise ValueError("invalid max messages: %s" % max_messages)
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if not queue:
return result
@ -168,12 +172,12 @@ def peek_queue(
def queue_object(
name: QueueNameType,
message: BaseModel,
storage_type: StorageType,
*,
account_id: str,
visibility_timeout: Optional[int] = None,
time_to_live: int = DEFAULT_TTL,
) -> bool:
queue = get_queue(name, account_id=account_id)
queue = get_queue(name, storage_type)
if not queue:
raise Exception("unable to queue object, no such queue: %s" % queue)